From f77c5803f4d61099e572be8d877b1c4a4d6087cd Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期三, 10 五月 2023 12:02:06 +0800
Subject: [PATCH] Merge pull request #485 from alibaba-damo-academy/main
---
funasr/runtime/grpc/Readme.md | 51
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py | 24
docs/model_zoo/modelscope_models.md | 126
funasr/runtime/onnxruntime/src/tokenizer.h | 4
funasr/tasks/sa_asr.py | 623 ++
egs/alimeeting/sa-asr/local/format_wav_scp.sh | 142
egs/alimeeting/sa-asr/local/validate_data_dir.sh | 404 +
funasr/bin/vad_inference.py | 13
funasr/runtime/websocket/readme.md | 99
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/demo.py | 11
tests/test_asr_inference_pipeline.py | 16
funasr/runtime/onnxruntime/src/resample.cpp | 2
funasr/bin/sa_asr_train.py | 47
funasr/runtime/onnxruntime/src/audio.cpp | 42
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py | 4
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py | 12
funasr/models/encoder/conformer_encoder.py | 4
egs/alimeeting/sa-asr/local/gen_oracle_embedding.py | 70
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py | 45
funasr/models/pooling/statistic_pooling.py | 4
funasr/runtime/python/websocket/README.md | 109
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py | 4
funasr/runtime/grpc/paraformer-server.cc | 50
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.sh | 1
egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl | 38
funasr/export/models/CT_Transformer.py | 4
docs/academic_recipe/vad_recipe.md | 129
funasr/utils/timestamp_tools.py | 38
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md | 1
funasr/bin/build_trainer.py | 3
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.sh | 1
funasr/runtime/python/websocket/ws_server_2pass.py | 182
egs/alimeeting/sa-asr/local/apply_map.pl | 97
funasr/models/decoder/rnnt_decoder.py | 12
funasr/runtime/onnxruntime/src/fsmn-vad.h | 9
docs/reference/papers.md | 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh | 1
funasr/runtime/websocket/websocketsrv.cpp | 158
docs/academic_recipe/lm_recipe.md | 128
funasr/bin/asr_train.py | 7
funasr/bin/asr_inference_rnnt.py | 19
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py | 4
egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh | 591 ++
egs/alimeeting/sa-asr/local/download_xvector_model.py | 6
egs_modelscope/punctuation/TEMPLATE/infer.sh | 66
egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py | 22
egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py | 235
docs/modelscope_pipeline/quick_start.md | 2
egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py | 127
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml | 87
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh | 0
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp | 34
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py | 4
egs_modelscope/tp/TEMPLATE/infer.py | 0
funasr/runtime/onnxruntime/src/offline-stream.cpp | 64
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml | 115
egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py | 59
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/utils | 1
egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml | 6
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py | 4
funasr/models/e2e_asr_contextual_paraformer.py | 372 +
funasr/runtime/onnxruntime/src/paraformer.cpp | 80
egs/alimeeting/sa-asr/local/copy_data_dir.sh | 145
funasr/runtime/python/websocket/parse_args.py | 7
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py | 4
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py | 4
funasr/runtime/onnxruntime/src/alignedmem.cpp | 3
funasr/torch_utils/load_pretrained_model.py | 2
funasr/runtime/onnxruntime/include/offline-stream.h | 30
egs_modelscope/speaker_verification/TEMPLATE/README.md | 12
egs/alimeeting/sa-asr/asr_local.sh | 1572 ++++++
egs/alimeeting/sa-asr/local/process_text_spk_merge.py | 55
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/infer.py | 4
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py | 1
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp | 143
funasr/train/trainer.py | 14
setup.py | 7
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh | 116
egs/alimeeting/sa-asr/local/data/get_reco2dur.sh | 143
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 6
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py | 39
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py | 4
egs_modelscope/vad/TEMPLATE/README.md | 38
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py | 4
funasr/runtime/onnxruntime/src/tensor.h | 4
funasr/bin/punctuation_infer_vadrealtime.py | 4
egs_modelscope/punctuation/TEMPLATE/infer.py | 23
funasr/runtime/onnxruntime/src/predefine-coe.h | 3
egs/alimeeting/sa-asr/local/process_text_id.py | 24
funasr/runtime/onnxruntime/include/com-define.h | 42
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo.py | 4
egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py | 167
egs_modelscope/tp/TEMPLATE/infer.sh | 2
egs/alimeeting/sa-asr/utils | 1
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py | 4
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo_online.py | 4
funasr/datasets/large_datasets/dataset.py | 37
egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl | 27
funasr/bin/asr_inference.py | 27
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py | 4
docs/installation/installation.md | 0
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh | 6
funasr/runtime/websocket/websocketmain.cpp | 149
funasr/modules/nets_utils.py | 35
.gitignore | 3
funasr/models/e2e_sa_asr.py | 520 ++
funasr/runtime/onnxruntime/src/tokenizer.cpp | 7
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py | 4
egs/alimeeting/sa-asr/README.md | 79
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/infer.py | 4
funasr/datasets/large_datasets/utils/padding.py | 58
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py | 1
funasr/bin/sa_asr_inference.py | 687 ++
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py | 0
funasr/runtime/onnxruntime/src/commonfunc.h | 9
docs/academic_recipe/punc_recipe.md | 129
egs/alimeeting/sa-asr/local/data/get_utt2dur.sh | 135
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/infer.py | 4
docs/reference/application.md | 0
funasr/modules/attention.py | 37
funasr/runtime/python/websocket/ws_server_offline.py | 150
docs/README.md | 19
egs_modelscope/asr/paraformer/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/infer.py | 4
funasr/runtime/onnxruntime/src/util.cpp | 3
funasr/runtime/onnxruntime/include/funasrruntime.h | 88
docs/installation/docker.md | 0
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh | 1
docs/reference/FQA.md | 0
egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py | 68
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py | 4
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py | 4
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md | 1
funasr/runtime/onnxruntime/src/punc-model.cpp | 22
egs/alimeeting/sa-asr/local/format_wav_scp.py | 243
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py | 6
egs/alimeeting/sa-asr/local/validate_text.pl | 136
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo_online.py | 4
funasr/losses/label_smoothing_loss.py | 46
docs/modelscope_pipeline/itn_pipeline.md | 63
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md | 1
funasr/runtime/websocket/CMakeLists.txt | 64
funasr/bin/vad_inference_online.py | 5
funasr/models/frontend/default.py | 11
egs_modelscope/asr/TEMPLATE/README.md | 58
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py | 4
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md | 1
egs_modelscope/punctuation/TEMPLATE/utils | 1
egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py | 4
funasr/runtime/onnxruntime/src/ct-transformer.cpp | 11
funasr/runtime/onnxruntime/readme.md | 126
funasr/runtime/onnxruntime/src/model.cpp | 17
funasr/runtime/onnxruntime/src/online-feature.h | 7
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh | 129
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md | 1
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh | 162
egs/alimeeting/sa-asr/path.sh | 5
funasr/runtime/grpc/paraformer-server.h | 2
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py | 1
funasr/bin/asr_inference_paraformer.py | 3
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py | 4
funasr/runtime/onnxruntime/src/precomp.h | 11
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md | 1
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.sh | 1
funasr/runtime/onnxruntime/include/vad-model.h | 29
funasr/models/encoder/sanm_encoder.py | 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py | 4
funasr/tasks/abs_task.py | 22
funasr/utils/postprocess_utils.py | 2
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh | 1
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py | 16
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md | 1
funasr/runtime/onnxruntime/src/e2e-vad.h | 72
funasr/runtime/python/grpc/proto/paraformer.proto | 14
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py | 4
docs/academic_recipe/sd_recipe.md | 129
funasr/runtime/onnxruntime/src/vad-model.cpp | 24
docs/index.rst | 18
funasr/runtime/onnxruntime/include/model.h | 8
funasr/runtime/websocket/websocketsrv.h | 93
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 50
egs/alimeeting/sa-asr/run.sh | 50
funasr/runtime/onnxruntime/src/ct-transformer.h | 6
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py | 4
egs_modelscope/speaker_diarization/TEMPLATE/README.md | 10
funasr/models/decoder/transformer_decoder.py | 428 +
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py | 4
docs/reference/build_task.md | 0
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py | 4
funasr/runtime/onnxruntime/include/audio.h | 9
funasr/bin/asr_inference_paraformer_streaming.py | 53
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py | 4
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py | 4
funasr/runtime/onnxruntime/CMakeLists.txt | 3
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py | 4
funasr/models/e2e_asr_transducer.py | 8
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py | 1
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md | 1
egs/alimeeting/sa-asr/local/combine_data.sh | 146
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py | 37
egs/alimeeting/sa-asr/local/fix_data_dir.sh | 215
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/infer.py | 4
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml | 29
docs/runtime/websocket_cpp.md | 1
funasr/runtime/websocket/websocketclient.cpp | 221
funasr/runtime/onnxruntime/src/paraformer.h | 24
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.sh | 1
funasr/runtime/onnxruntime/include/punc-model.h | 20
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py | 4
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp | 58
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py | 4
funasr/runtime/onnxruntime/src/CMakeLists.txt | 4
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py | 1
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md | 1
funasr/runtime/python/grpc/Readme.md | 2
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 10
funasr/runtime/onnxruntime/src/resample.h | 5
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp | 98
docs/modelscope_pipeline/punc_pipeline.md | 1
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py | 1
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py | 4
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/infer.py | 4
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md | 264
egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py | 86
egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh | 29
funasr/runtime/onnxruntime/src/util.h | 4
README.md | 12
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/infer.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py | 4
funasr/runtime/onnxruntime/src/alignedmem.h | 2
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py | 30
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py | 1
funasr/datasets/large_datasets/utils/hotword_utils.py | 32
funasr/datasets/large_datasets/utils/tokenize.py | 8
funasr/runtime/onnxruntime/src/vocab.h | 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py | 4
egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py | 158
docs/academic_recipe/sv_recipe.md | 129
funasr/fileio/sound_scp.py | 6
funasr/version.txt | 2
egs_modelscope/vad/TEMPLATE/infer.sh | 2
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 362 +
funasr/runtime/python/websocket/ws_server_online.py | 51
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh | 1
egs/alimeeting/sa-asr/local/compute_cpcer.py | 91
egs/alimeeting/sa-asr/local/text_normalize.pl | 38
egs_modelscope/tp/TEMPLATE/README.md | 42
funasr/runtime/onnxruntime/src/online-feature.cpp | 4
egs/alimeeting/sa-asr/local/data/split_data.sh | 160
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py | 1
egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py | 4
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md | 1
funasr/runtime/python/onnxruntime/setup.py | 2
funasr/modules/beam_search/beam_search_sa_asr.py | 525 ++
egs/alimeeting/sa-asr/local/text_format.pl | 14
egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py | 4
funasr/tasks/asr.py | 5
egs_modelscope/punctuation/TEMPLATE/README.md | 48
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh | 1
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py | 4
docs/model_zoo/huggingface_models.md | 0
/dev/null | 210
egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py | 4
funasr/runtime/python/websocket/ws_client.py | 127
funasr/modules/repeat.py | 4
funasr/bin/asr_inference_launch.py | 13
funasr/runtime/onnxruntime/src/vocab.cpp | 3
272 files changed, 13,499 insertions(+), 1,702 deletions(-)
diff --git a/.gitignore b/.gitignore
index 33b8c39..c4b031f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,5 @@
.egg*
dist
build
-funasr.egg-info
\ No newline at end of file
+funasr.egg-info
+docs/_build
\ No newline at end of file
diff --git a/README.md b/README.md
index 665f425..64d6d89 100644
--- a/README.md
+++ b/README.md
@@ -13,10 +13,10 @@
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
-| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
+| [**Tutorial_CN**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
-| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
+| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
| [**Contact**](#contact)
| [**M2MET2.0 Challenge**](https://github.com/alibaba-damo-academy/FunASR#multi-channel-multi-party-meeting-transcription-20-m2met20-challenge)
@@ -28,7 +28,7 @@
## Highlights
- FunASR supports speech recognition(ASR), Multi-talker ASR, Voice Activity Detection(VAD), Punctuation Restoration, Language Models, Speaker Verification and Speaker diarization.
-- We have released large number of academic and industrial pretrained models on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition)
+- We have released large number of academic and industrial pretrained models on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition), ref to [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
- The pretrained model [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) obtains the best performance on many tasks in [SpeechIO leaderboard](https://github.com/SpeechColab/Leaderboard)
- FunASR supplies a easy-to-use pipeline to finetune pretrained models from [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition)
- Compared to [Espnet](https://github.com/espnet/espnet) framework, the training speed of large-scale datasets in FunASR is much faster owning to the optimized dataloader.
@@ -60,12 +60,8 @@
# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
-For more details, please ref to [installation](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)
+For more details, please ref to [installation](https://alibaba-damo-academy.github.io/FunASR/en/installation/installation.html)
-[//]: # ()
-[//]: # (## Usage)
-
-[//]: # (For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)))
## Contact
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..4e16b04
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,19 @@
+# FunASR document generation
+
+## Generate HTML
+For convenience, we provide users with the ability to generate local HTML manually.
+
+First, you should install the following packages, which is required for building HTML:
+```sh
+conda activate funasr
+pip install requests sphinx nbsphinx sphinx_markdown_tables sphinx_rtd_theme recommonmark
+```
+
+Then you can generate HTML manually.
+
+```sh
+cd docs
+make html
+```
+
+The generated files are all contained in the "FunASR/docs/_build" directory. You can access the FunASR documentation by simply opening the "html/index.html" file in your browser from this directory.
\ No newline at end of file
diff --git a/docs/academic_recipe/lm_recipe.md b/docs/academic_recipe/lm_recipe.md
index f82a6fe..730e27c 100644
--- a/docs/academic_recipe/lm_recipe.md
+++ b/docs/academic_recipe/lm_recipe.md
@@ -1,129 +1,3 @@
# Speech Recognition
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
-## Overall Introduction
-We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
-- `CUDA_VISIBLE_DEVICES`: visible gpu list
-- `gpu_num`: the number of GPUs used for training
-- `gpu_inference`: whether to use GPUs for decoding
-- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
-- `feats_dir`: the path for saving processed data
-- `nj`: the number of jobs for data preparation
-- `speed_perturb`: the range of speech perturbed
-- `exp_dir`: the path for saving experimental results
-- `tag`: the suffix of experimental result directory
-
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
-* `wav.scp`
-```
-BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
-BAC009S0002W0123 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
-BAC009S0002W0124 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
-...
-```
-* `text`
-```
-BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
-BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
-BAC009S0002W0124 鑷� 鍏� 鏈� 搴� 鍛� 鍜� 娴� 鐗� 甯� 鐜� 鍏� 瀹� 甯� 鍙� 娑� 闄� 璐� 鍚�
-...
-```
-These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
-
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
-
-## Stage 2: Dictionary Preparation
-This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
-* `tokens.txt`
-```
-<blank>
-<s>
-</s>
-涓�
-涓�
-...
-榫�
-榫�
-<unk>
-```
-* `<blank>`: indicates the blank token for CTC
-* `<s>`: indicates the start-of-sentence token
-* `</s>`: indicates the end-of-sentence token
-* `<unk>`: indicates the out-of-vocabulary token
-
-## Stage 3: Training
-This stage achieves the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding.
-
-* DDP Training
-
-We support the DistributedDataParallel (DDP) training and the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). To enable DDP training, please set `gpu_num` greater than 1. For example, if you set `CUDA_VISIBLE_DEVICES=0,1,5,6,7` and `gpu_num=3`, then the gpus with ids 0, 1 and 5 will be used for training.
-
-* DataLoader
-
-We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large dataset and users can set `dataset_type=large` to enable it.
-
-* Configuration
-
-The parameters of the training, including model, optimization, dataset, etc., can be set by a YAML file in `conf` directory. Also, users can directly set the parameters in `run.sh` recipe. Please avoid to set the same parameters in both the YAML file and the recipe.
-
-* Training Steps
-
-We support two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
-
-* Tensorboard
-
-Users can use tensorboard to observe the loss, learning rate, etc. Please run the following command:
-```
-tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
-```
-
-## Stage 4: Decoding
-This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
-
-* Mode Selection
-
-As we support paraformer, uniasr, conformer and other models in FunASR, a `mode` parameter should be specified as `asr/paraformer/uniasr` according to the trained model.
-
-* Configuration
-
-We support CTC decoding, attention decoding and hybrid CTC-attention decoding in FunASR, which can be specified by `ctc_weight` in a YAML file in `conf` directory. Specifically, `ctc_weight=1.0` indicates CTC decoding, `ctc_weight=0.0` indicates attention decoding, `0.0<ctc_weight<1.0` indicates hybrid CTC-attention decoding.
-
-* CPU/GPU Decoding
-
-We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu_inference=False` and set `njob` to specify the total number of CPU decoding jobs. For GPU decoding, you should set `gpu_inference=True`. You should also set `gpuid_list` to indicate which GPUs are used for decoding and `njobs` to indicate the number of decoding jobs on each GPU.
-
-* Performance
-
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
-* `text.cer`
-```
-...
-BAC009S0764W0213(nwords=11,cor=11,ins=0,del=0,sub=0) corr=100.00%,cer=0.00%
-ref: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-res: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-...
-```
-
+Undo
diff --git a/docs/academic_recipe/punc_recipe.md b/docs/academic_recipe/punc_recipe.md
index 0306cd3..e9f79bb 100644
--- a/docs/academic_recipe/punc_recipe.md
+++ b/docs/academic_recipe/punc_recipe.md
@@ -1,129 +1,2 @@
# Punctuation Restoration
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
-
-## Overall Introduction
-We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
-- `CUDA_VISIBLE_DEVICES`: visible gpu list
-- `gpu_num`: the number of GPUs used for training
-- `gpu_inference`: whether to use GPUs for decoding
-- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
-- `feats_dir`: the path for saving processed data
-- `nj`: the number of jobs for data preparation
-- `speed_perturb`: the range of speech perturbed
-- `exp_dir`: the path for saving experimental results
-- `tag`: the suffix of experimental result directory
-
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
-* `wav.scp`
-```
-BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
-BAC009S0002W0123 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
-BAC009S0002W0124 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
-...
-```
-* `text`
-```
-BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
-BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
-BAC009S0002W0124 鑷� 鍏� 鏈� 搴� 鍛� 鍜� 娴� 鐗� 甯� 鐜� 鍏� 瀹� 甯� 鍙� 娑� 闄� 璐� 鍚�
-...
-```
-These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
-
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
-
-## Stage 2: Dictionary Preparation
-This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
-* `tokens.txt`
-```
-<blank>
-<s>
-</s>
-涓�
-涓�
-...
-榫�
-榫�
-<unk>
-```
-* `<blank>`: indicates the blank token for CTC
-* `<s>`: indicates the start-of-sentence token
-* `</s>`: indicates the end-of-sentence token
-* `<unk>`: indicates the out-of-vocabulary token
-
-## Stage 3: Training
-This stage achieves the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding.
-
-* DDP Training
-
-We support the DistributedDataParallel (DDP) training and the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). To enable DDP training, please set `gpu_num` greater than 1. For example, if you set `CUDA_VISIBLE_DEVICES=0,1,5,6,7` and `gpu_num=3`, then the gpus with ids 0, 1 and 5 will be used for training.
-
-* DataLoader
-
-We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large dataset and users can set `dataset_type=large` to enable it.
-
-* Configuration
-
-The parameters of the training, including model, optimization, dataset, etc., can be set by a YAML file in `conf` directory. Also, users can directly set the parameters in `run.sh` recipe. Please avoid to set the same parameters in both the YAML file and the recipe.
-
-* Training Steps
-
-We support two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
-
-* Tensorboard
-
-Users can use tensorboard to observe the loss, learning rate, etc. Please run the following command:
-```
-tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
-```
-
-## Stage 4: Decoding
-This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
-
-* Mode Selection
-
-As we support paraformer, uniasr, conformer and other models in FunASR, a `mode` parameter should be specified as `asr/paraformer/uniasr` according to the trained model.
-
-* Configuration
-
-We support CTC decoding, attention decoding and hybrid CTC-attention decoding in FunASR, which can be specified by `ctc_weight` in a YAML file in `conf` directory. Specifically, `ctc_weight=1.0` indicates CTC decoding, `ctc_weight=0.0` indicates attention decoding, `0.0<ctc_weight<1.0` indicates hybrid CTC-attention decoding.
-
-* CPU/GPU Decoding
-
-We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu_inference=False` and set `njob` to specify the total number of CPU decoding jobs. For GPU decoding, you should set `gpu_inference=True`. You should also set `gpuid_list` to indicate which GPUs are used for decoding and `njobs` to indicate the number of decoding jobs on each GPU.
-
-* Performance
-
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
-* `text.cer`
-```
-...
-BAC009S0764W0213(nwords=11,cor=11,ins=0,del=0,sub=0) corr=100.00%,cer=0.00%
-ref: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-res: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-...
-```
-
+Undo
\ No newline at end of file
diff --git a/docs/academic_recipe/sd_recipe.md b/docs/academic_recipe/sd_recipe.md
index 90eb4b3..8b38d7b 100644
--- a/docs/academic_recipe/sd_recipe.md
+++ b/docs/academic_recipe/sd_recipe.md
@@ -1,129 +1,2 @@
# Speaker Diarization
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
-
-## Overall Introduction
-We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
-- `CUDA_VISIBLE_DEVICES`: visible gpu list
-- `gpu_num`: the number of GPUs used for training
-- `gpu_inference`: whether to use GPUs for decoding
-- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
-- `feats_dir`: the path for saving processed data
-- `nj`: the number of jobs for data preparation
-- `speed_perturb`: the range of speech perturbed
-- `exp_dir`: the path for saving experimental results
-- `tag`: the suffix of experimental result directory
-
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
-* `wav.scp`
-```
-BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
-BAC009S0002W0123 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
-BAC009S0002W0124 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
-...
-```
-* `text`
-```
-BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
-BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
-BAC009S0002W0124 鑷� 鍏� 鏈� 搴� 鍛� 鍜� 娴� 鐗� 甯� 鐜� 鍏� 瀹� 甯� 鍙� 娑� 闄� 璐� 鍚�
-...
-```
-These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
-
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
-
-## Stage 2: Dictionary Preparation
-This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
-* `tokens.txt`
-```
-<blank>
-<s>
-</s>
-涓�
-涓�
-...
-榫�
-榫�
-<unk>
-```
-* `<blank>`: indicates the blank token for CTC
-* `<s>`: indicates the start-of-sentence token
-* `</s>`: indicates the end-of-sentence token
-* `<unk>`: indicates the out-of-vocabulary token
-
-## Stage 3: Training
-This stage achieves the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding.
-
-* DDP Training
-
-We support the DistributedDataParallel (DDP) training and the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). To enable DDP training, please set `gpu_num` greater than 1. For example, if you set `CUDA_VISIBLE_DEVICES=0,1,5,6,7` and `gpu_num=3`, then the gpus with ids 0, 1 and 5 will be used for training.
-
-* DataLoader
-
-We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large dataset and users can set `dataset_type=large` to enable it.
-
-* Configuration
-
-The parameters of the training, including model, optimization, dataset, etc., can be set by a YAML file in `conf` directory. Also, users can directly set the parameters in `run.sh` recipe. Please avoid to set the same parameters in both the YAML file and the recipe.
-
-* Training Steps
-
-We support two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
-
-* Tensorboard
-
-Users can use tensorboard to observe the loss, learning rate, etc. Please run the following command:
-```
-tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
-```
-
-## Stage 4: Decoding
-This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
-
-* Mode Selection
-
-As we support paraformer, uniasr, conformer and other models in FunASR, a `mode` parameter should be specified as `asr/paraformer/uniasr` according to the trained model.
-
-* Configuration
-
-We support CTC decoding, attention decoding and hybrid CTC-attention decoding in FunASR, which can be specified by `ctc_weight` in a YAML file in `conf` directory. Specifically, `ctc_weight=1.0` indicates CTC decoding, `ctc_weight=0.0` indicates attention decoding, `0.0<ctc_weight<1.0` indicates hybrid CTC-attention decoding.
-
-* CPU/GPU Decoding
-
-We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu_inference=False` and set `njob` to specify the total number of CPU decoding jobs. For GPU decoding, you should set `gpu_inference=True`. You should also set `gpuid_list` to indicate which GPUs are used for decoding and `njobs` to indicate the number of decoding jobs on each GPU.
-
-* Performance
-
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
-* `text.cer`
-```
-...
-BAC009S0764W0213(nwords=11,cor=11,ins=0,del=0,sub=0) corr=100.00%,cer=0.00%
-ref: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-res: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-...
-```
-
+Undo
diff --git a/docs/academic_recipe/sv_recipe.md b/docs/academic_recipe/sv_recipe.md
index 0eebe3d..7fe493b 100644
--- a/docs/academic_recipe/sv_recipe.md
+++ b/docs/academic_recipe/sv_recipe.md
@@ -1,129 +1,2 @@
# Speaker Verification
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
-
-## Overall Introduction
-We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
-- `CUDA_VISIBLE_DEVICES`: visible gpu list
-- `gpu_num`: the number of GPUs used for training
-- `gpu_inference`: whether to use GPUs for decoding
-- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
-- `feats_dir`: the path for saving processed data
-- `nj`: the number of jobs for data preparation
-- `speed_perturb`: the range of speech perturbed
-- `exp_dir`: the path for saving experimental results
-- `tag`: the suffix of experimental result directory
-
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
-* `wav.scp`
-```
-BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
-BAC009S0002W0123 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
-BAC009S0002W0124 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
-...
-```
-* `text`
-```
-BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
-BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
-BAC009S0002W0124 鑷� 鍏� 鏈� 搴� 鍛� 鍜� 娴� 鐗� 甯� 鐜� 鍏� 瀹� 甯� 鍙� 娑� 闄� 璐� 鍚�
-...
-```
-These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
-
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
-
-## Stage 2: Dictionary Preparation
-This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
-* `tokens.txt`
-```
-<blank>
-<s>
-</s>
-涓�
-涓�
-...
-榫�
-榫�
-<unk>
-```
-* `<blank>`: indicates the blank token for CTC
-* `<s>`: indicates the start-of-sentence token
-* `</s>`: indicates the end-of-sentence token
-* `<unk>`: indicates the out-of-vocabulary token
-
-## Stage 3: Training
-This stage achieves the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding.
-
-* DDP Training
-
-We support the DistributedDataParallel (DDP) training and the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). To enable DDP training, please set `gpu_num` greater than 1. For example, if you set `CUDA_VISIBLE_DEVICES=0,1,5,6,7` and `gpu_num=3`, then the gpus with ids 0, 1 and 5 will be used for training.
-
-* DataLoader
-
-We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large dataset and users can set `dataset_type=large` to enable it.
-
-* Configuration
-
-The parameters of the training, including model, optimization, dataset, etc., can be set by a YAML file in `conf` directory. Also, users can directly set the parameters in `run.sh` recipe. Please avoid to set the same parameters in both the YAML file and the recipe.
-
-* Training Steps
-
-We support two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
-
-* Tensorboard
-
-Users can use tensorboard to observe the loss, learning rate, etc. Please run the following command:
-```
-tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
-```
-
-## Stage 4: Decoding
-This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
-
-* Mode Selection
-
-As we support paraformer, uniasr, conformer and other models in FunASR, a `mode` parameter should be specified as `asr/paraformer/uniasr` according to the trained model.
-
-* Configuration
-
-We support CTC decoding, attention decoding and hybrid CTC-attention decoding in FunASR, which can be specified by `ctc_weight` in a YAML file in `conf` directory. Specifically, `ctc_weight=1.0` indicates CTC decoding, `ctc_weight=0.0` indicates attention decoding, `0.0<ctc_weight<1.0` indicates hybrid CTC-attention decoding.
-
-* CPU/GPU Decoding
-
-We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu_inference=False` and set `njob` to specify the total number of CPU decoding jobs. For GPU decoding, you should set `gpu_inference=True`. You should also set `gpuid_list` to indicate which GPUs are used for decoding and `njobs` to indicate the number of decoding jobs on each GPU.
-
-* Performance
-
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
-* `text.cer`
-```
-...
-BAC009S0764W0213(nwords=11,cor=11,ins=0,del=0,sub=0) corr=100.00%,cer=0.00%
-ref: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-res: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-...
-```
-
+Undo
diff --git a/docs/academic_recipe/vad_recipe.md b/docs/academic_recipe/vad_recipe.md
index 6aa7532..0216bc3 100644
--- a/docs/academic_recipe/vad_recipe.md
+++ b/docs/academic_recipe/vad_recipe.md
@@ -1,129 +1,2 @@
# Voice Activity Detection
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
-
-## Overall Introduction
-We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
-- `CUDA_VISIBLE_DEVICES`: visible gpu list
-- `gpu_num`: the number of GPUs used for training
-- `gpu_inference`: whether to use GPUs for decoding
-- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
-- `feats_dir`: the path for saving processed data
-- `nj`: the number of jobs for data preparation
-- `speed_perturb`: the range of speech perturbed
-- `exp_dir`: the path for saving experimental results
-- `tag`: the suffix of experimental result directory
-
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
-* `wav.scp`
-```
-BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
-BAC009S0002W0123 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
-BAC009S0002W0124 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
-...
-```
-* `text`
-```
-BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
-BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
-BAC009S0002W0124 鑷� 鍏� 鏈� 搴� 鍛� 鍜� 娴� 鐗� 甯� 鐜� 鍏� 瀹� 甯� 鍙� 娑� 闄� 璐� 鍚�
-...
-```
-These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
-
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
-
-## Stage 2: Dictionary Preparation
-This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
-* `tokens.txt`
-```
-<blank>
-<s>
-</s>
-涓�
-涓�
-...
-榫�
-榫�
-<unk>
-```
-* `<blank>`: indicates the blank token for CTC
-* `<s>`: indicates the start-of-sentence token
-* `</s>`: indicates the end-of-sentence token
-* `<unk>`: indicates the out-of-vocabulary token
-
-## Stage 3: Training
-This stage achieves the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding.
-
-* DDP Training
-
-We support the DistributedDataParallel (DDP) training and the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). To enable DDP training, please set `gpu_num` greater than 1. For example, if you set `CUDA_VISIBLE_DEVICES=0,1,5,6,7` and `gpu_num=3`, then the gpus with ids 0, 1 and 5 will be used for training.
-
-* DataLoader
-
-We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large dataset and users can set `dataset_type=large` to enable it.
-
-* Configuration
-
-The parameters of the training, including model, optimization, dataset, etc., can be set by a YAML file in `conf` directory. Also, users can directly set the parameters in `run.sh` recipe. Please avoid to set the same parameters in both the YAML file and the recipe.
-
-* Training Steps
-
-We support two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
-
-* Tensorboard
-
-Users can use tensorboard to observe the loss, learning rate, etc. Please run the following command:
-```
-tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
-```
-
-## Stage 4: Decoding
-This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
-
-* Mode Selection
-
-As we support paraformer, uniasr, conformer and other models in FunASR, a `mode` parameter should be specified as `asr/paraformer/uniasr` according to the trained model.
-
-* Configuration
-
-We support CTC decoding, attention decoding and hybrid CTC-attention decoding in FunASR, which can be specified by `ctc_weight` in a YAML file in `conf` directory. Specifically, `ctc_weight=1.0` indicates CTC decoding, `ctc_weight=0.0` indicates attention decoding, `0.0<ctc_weight<1.0` indicates hybrid CTC-attention decoding.
-
-* CPU/GPU Decoding
-
-We support CPU and GPU decoding in FunASR. For CPU decoding, you should set `gpu_inference=False` and set `njob` to specify the total number of CPU decoding jobs. For GPU decoding, you should set `gpu_inference=True`. You should also set `gpuid_list` to indicate which GPUs are used for decoding and `njobs` to indicate the number of decoding jobs on each GPU.
-
-* Performance
-
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
-* `text.cer`
-```
-...
-BAC009S0764W0213(nwords=11,cor=11,ins=0,del=0,sub=0) corr=100.00%,cer=0.00%
-ref: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-res: 鏋� 寤� 鑹� 濂� 鐨� 鏃� 娓� 甯� 鍦� 鐜� 澧�
-...
-```
-
+Undo
diff --git a/docs/index.rst b/docs/index.rst
index e6aff5f..c2656bd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -17,8 +17,8 @@
:maxdepth: 1
:caption: Installation
- ./installation.md
- ./docker.md
+ ./installation/installation.md
+ ./installation/docker.md
.. toctree::
:maxdepth: 1
@@ -44,6 +44,7 @@
./modelscope_pipeline/tp_pipeline.md
./modelscope_pipeline/sv_pipeline.md
./modelscope_pipeline/sd_pipeline.md
+ ./modelscope_pipeline/itn_pipeline.md
.. toctree::
:maxdepth: 1
@@ -56,8 +57,8 @@
:maxdepth: 1
:caption: Model Zoo
- ./modelscope_models.md
- ./huggingface_models.md
+ ./model_zoo/modelscope_models.md
+ ./model_zoo/huggingface_models.md
.. toctree::
:maxdepth: 1
@@ -70,6 +71,7 @@
./runtime/grpc_python.md
./runtime/grpc_cpp.md
./runtime/websocket_python.md
+ ./runtime/websocket_cpp.md
.. toctree::
:maxdepth: 1
@@ -84,25 +86,25 @@
:maxdepth: 1
:caption: Funasr Library
- ./build_task.md
+ ./reference/build_task.md
.. toctree::
:maxdepth: 1
:caption: Papers
- ./papers.md
+ ./reference/papers.md
.. toctree::
:maxdepth: 1
:caption: Application
- ./application.md
+ ./reference/application.md
.. toctree::
:maxdepth: 1
:caption: FQA
- ./FQA.md
+ ./reference/FQA.md
Indices and tables
diff --git a/docs/docker.md b/docs/installation/docker.md
similarity index 100%
rename from docs/docker.md
rename to docs/installation/docker.md
diff --git a/docs/installation.md b/docs/installation/installation.md
similarity index 100%
rename from docs/installation.md
rename to docs/installation/installation.md
diff --git a/docs/huggingface_models.md b/docs/model_zoo/huggingface_models.md
similarity index 100%
rename from docs/huggingface_models.md
rename to docs/model_zoo/huggingface_models.md
diff --git a/docs/model_zoo/modelscope_models.md b/docs/model_zoo/modelscope_models.md
new file mode 100644
index 0000000..1b7f475
--- /dev/null
+++ b/docs/model_zoo/modelscope_models.md
@@ -0,0 +1,126 @@
+# Pretrained Models on ModelScope
+
+## Model License
+- Apache License 2.0
+
+## Model Zoo
+Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
+
+### Speech Recognition Models
+#### Paraformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
+| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which ould deal with arbitrary length input wav |
+| [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
+| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
+| [Paraformer-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
+| [Paraformer-large-online](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Online | Which could deal with streaming input |
+| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition |
+| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+
+
+#### UniASR Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:-------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------:|:---------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
+| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
+| [UniASR English](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/summary) | EN | Alibaba Speech Data (10000 hours) | 1080 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Russian](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/summary) | RU | Alibaba Speech Data (5000 hours) | 1664 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Japanese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/summary) | JA | Alibaba Speech Data (5000 hours) | 5977 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Korean](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/summary) | KO | Alibaba Speech Data (2000 hours) | 6400 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Cantonese (CHS)](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/summary) | Cantonese (CHS) | Alibaba Speech Data (5000 hours) | 1468 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Indonesian](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/summary) | ID | Alibaba Speech Data (1000 hours) | 1067 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Vietnamese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary) | VI | Alibaba Speech Data (1000 hours) | 1001 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Spanish](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/summary) | ES | Alibaba Speech Data (1000 hours) | 3445 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Portuguese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/summary) | PT | Alibaba Speech Data (1000 hours) | 1617 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR French](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary) | FR | Alibaba Speech Data (1000 hours) | 3472 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR German](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary) | GE | Alibaba Speech Data (1000 hours) | 3690 | 95M | Online | UniASR streaming online unifying models |
+| [UniASR Persian](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary) | FA | Alibaba Speech Data (1000 hours) | 1257 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | MY | Alibaba Speech Data (1000 hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | HE | Alibaba Speech Data (1000 hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | UR | Alibaba Speech Data (1000 hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
+
+
+
+#### Conformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
+| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
+| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) | EN | Alibaba Speech Data (10000hours) | 4199 | 220M | Offline | Duration of input wav <= 20s |
+
+
+#### RNN-T Models
+
+### Multi-talker Speech Recognition Models
+
+#### MFCCA Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:--------:|:------------------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary) | CN | AliMeeting銆丄ISHELL-4銆丼imudata (917hours) | 4950 | 45M | Offline | Duration of input wav <= 20s, channel of input wav <= 8 channel |
+
+
+
+### Voice Activity Detection Models
+
+| Model Name | Training Data | Parameters | Sampling Rate | Notes |
+|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
+
+### Punctuation Restoration Models
+
+| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
+
+### Language Models
+
+| Model Name | Training Data | Parameters | Vocab Size | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
+| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
+
+### Speaker Verification Models
+
+| Model Name | Training Data | Parameters | Number Speaker | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (1,200 hours) | 17.5M | 3465 | Xvector, speaker verification, Chinese |
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (60 hours) | 61M | 6135 | Xvector, speaker verification, English |
+
+### Speaker Diarization Models
+
+| Model Name | Training Data | Parameters | Notes |
+|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (120 hours) | 40.5M | Speaker diarization, profiles and records, Chinese |
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (60 hours) | 12M | Speaker diarization, profiles and records, English |
+
+### Timestamp Prediction Models
+
+| Model Name | Language | Training Data | Parameters | Notes |
+|:--------------------------------------------------------------------------------------------------:|:--------------:|:-------------------:|:----------:|:------|
+| [TP-Aligner](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) | CN | Alibaba Speech Data (50000hours) | 37.8M | Timestamp prediction, Mandarin, middle size |
+
+### Inverse Text Normalization (ITN) Models
+
+| Model Name | Language | Parameters | Notes |
+|:----------------------------------------------------------------------------------------------------------------:|:--------:|:----------:|:-------------------------|
+| [English](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-en/summary) | EN | 1.54M | ITN, ASR post-processing |
+| [Russian](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-ru/summary) | RU | 17.79M | ITN, ASR post-processing |
+| [Japanese](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-ja/summary) | JA | 6.8M | ITN, ASR post-processing |
+| [Korean](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-ko/summary) | KO | 1.28M | ITN, ASR post-processing |
+| [Indonesian](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-id/summary) | ID | 2.06M | ITN, ASR post-processing |
+| [Vietnamese](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-vi/summary) | VI | 0.92M | ITN, ASR post-processing |
+| [Tagalog](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-tl/summary) | TL | 0.65M | ITN, ASR post-processing |
+| [Spanish](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-es/summary) | ES | 1.32M | ITN, ASR post-processing |
+| [Portuguese](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-pt/summary) | PT | 1.28M | ITN, ASR post-processing |
+| [French](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-fr/summary) | FR | 4.39M | ITN, ASR post-processing |
+| [German](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-de/summary)| GE | 3.95M | ITN, ASR post-processing |
diff --git a/docs/modelscope_models.md b/docs/modelscope_models.md
deleted file mode 100644
index 5f94a09..0000000
--- a/docs/modelscope_models.md
+++ /dev/null
@@ -1,94 +0,0 @@
-# Pretrained Models on ModelScope
-
-## Model License
-- Apache License 2.0
-
-## Model Zoo
-Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
-
-### Speech Recognition Models
-#### Paraformer Models
-
-| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
-|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
-| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
-| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which ould deal with arbitrary length input wav |
-| [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
-| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
-| [Paraformer-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
-| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition |
-| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
-| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
-| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
-| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
-
-
-#### UniASR Models
-
-| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
-|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
-| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
-| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
-| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
-| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
-| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
-
-#### Conformer Models
-
-| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
-|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
-| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
-| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
-
-
-#### RNN-T Models
-
-### Multi-talker Speech Recognition Models
-
-#### MFCCA Models
-
-| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
-|:-------------------------------------------------------------------------------------------------------------:|:--------:|:------------------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
-| [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary) | CN | AliMeeting銆丄ISHELL-4銆丼imudata (917hours) | 4950 | 45M | Offline | Duration of input wav <= 20s, channel of input wav <= 8 channel |
-
-
-
-### Voice Activity Detection Models
-
-| Model Name | Training Data | Parameters | Sampling Rate | Notes |
-|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
-| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
-| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
-
-### Punctuation Restoration Models
-
-| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
-|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
-
-### Language Models
-
-| Model Name | Training Data | Parameters | Vocab Size | Notes |
-|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
-| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
-
-### Speaker Verification Models
-
-| Model Name | Training Data | Parameters | Number Speaker | Notes |
-|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
-| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (1,200 hours) | 17.5M | 3465 | Xvector, speaker verification, Chinese |
-| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (60 hours) | 61M | 6135 | Xvector, speaker verification, English |
-
-### Speaker Diarization Models
-
-| Model Name | Training Data | Parameters | Notes |
-|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
-| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (120 hours) | 40.5M | Speaker diarization, profiles and records, Chinese |
-| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (60 hours) | 12M | Speaker diarization, profiles and records, English |
-
-### Timestamp Prediction Models
-
-| Model Name | Language | Training Data | Parameters | Notes |
-|:--------------------------------------------------------------------------------------------------:|:--------------:|:-------------------:|:----------:|:------|
-| [TP-Aligner](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) | CN | Alibaba Speech Data (50000hours) | 37.8M | Timestamp prediction, Mandarin, middle size |
diff --git a/docs/modelscope_pipeline/itn_pipeline.md b/docs/modelscope_pipeline/itn_pipeline.md
new file mode 100644
index 0000000..2336842
--- /dev/null
+++ b/docs/modelscope_pipeline/itn_pipeline.md
@@ -0,0 +1,63 @@
+# Inverse Text Normalization (ITN)
+
+> **Note**:
+> The modelscope pipeline supports all the models in [model zoo](https://modelscope.cn/models?page=1&tasks=inverse-text-processing&type=audio) to inference. Here we take the model of the Japanese ITN model as example to demonstrate the usage.
+
+## Inference
+
+### Quick start
+#### [Japanese ITN model](https://modelscope.cn/models/damo/speech_inverse_text_processing_fun-text-processing-itn-ja/summary)
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+itn_inference_pipline = pipeline(
+ task=Tasks.inverse_text_processing,
+ model='damo/speech_inverse_text_processing_fun-text-processing-itn-ja',
+ model_revision=None)
+
+itn_result = itn_inference_pipline(text_in='鐧句簩鍗佷笁')
+print(itn_result)
+# 123
+```
+- read text data directly.
+```python
+rec_result = inference_pipeline(text_in='涓�涔濅節涔濆勾銇獣鐢熴仐銇熷悓鍟嗗搧銇仭銇伩銆佺磩涓夊崄骞村墠銆佷簩鍗佸洓姝炽伄闋冦伄骞稿洓閮庛伄鍐欑湡銈掑叕闁嬨��')
+# 1999骞淬伀瑾曠敓銇椼仧鍚屽晢鍝併伀銇°仾銇裤�佺磩30骞村墠銆�24姝炽伄闋冦伄骞稿洓閮庛伄鍐欑湡銈掑叕闁嬨��
+```
+- text stored via url锛宔xample锛歨ttps://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/ja_itn_example.txt
+```python
+rec_result = inference_pipeline(text_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/ja_itn_example.txt')
+```
+
+Full code of demo, please ref to [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/fun_text_processing/inverse_text_normalization)
+
+### API-reference
+#### Define pipeline
+- `task`: `Tasks.inverse_text_processing`
+- `model`: model name in [model zoo](https://modelscope.cn/models?page=1&tasks=inverse-text-processing&type=audio), or model path in local disk
+- `output_dir`: `None` (Default), the output path of results if set
+- `model_revision`: `None` (Default), setting the model version
+
+#### Infer pipeline
+- `text_in`: the input to decode, which could be:
+ - text bytes, `e.g.`: "涓�涔濅節涔濆勾銇獣鐢熴仐銇熷悓鍟嗗搧銇仭銇伩銆佺磩涓夊崄骞村墠銆佷簩鍗佸洓姝炽伄闋冦伄骞稿洓閮庛伄鍐欑湡銈掑叕闁嬨��"
+ - text file, `e.g.`: https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/ja_itn_example.txt
+ In this case of `text file` input, `output_dir` must be set to save the output results
+
+## Modify Your Own ITN Model
+The rule-based ITN code is open-sourced in [FunTextProcessing](https://github.com/alibaba-damo-academy/FunASR/tree/main/fun_text_processing), users can modify by their own grammar rules for different languages. Let's take Japanese as an example, users can add their own whitelist in ```FunASR/fun_text_processing/inverse_text_normalization/ja/data/whitelist.tsv```. After modified the grammar rules, the users can export and evaluate their own ITN models in local directory.
+
+### Export ITN Model
+Export ITN model via ```FunASR/fun_text_processing/inverse_text_normalization/export_models.py```. An example to export ITN model to local folder is shown as below.
+```shell
+cd FunASR/fun_text_processing/inverse_text_normalization/
+python export_models.py --language ja --export_dir ./itn_models/
+```
+
+### Evaluate ITN Model
+Users can evaluate their own ITN model in local directory via ```FunASR/fun_text_processing/inverse_text_normalization/inverse_normalize.py```. Here is an example:
+```shell
+cd FunASR/fun_text_processing/inverse_text_normalization/
+python inverse_normalize.py --input_file ja_itn_example.txt --cache_dir ./itn_models/ --output_file output.txt --language=ja
+```
\ No newline at end of file
diff --git a/docs/modelscope_pipeline/punc_pipeline.md b/docs/modelscope_pipeline/punc_pipeline.md
new file mode 120000
index 0000000..4ef4711
--- /dev/null
+++ b/docs/modelscope_pipeline/punc_pipeline.md
@@ -0,0 +1 @@
+../../egs_modelscope/punctuation/TEMPLATE/README.md
\ No newline at end of file
diff --git a/docs/modelscope_pipeline/quick_start.md b/docs/modelscope_pipeline/quick_start.md
index 436fb1d..7e35e91 100644
--- a/docs/modelscope_pipeline/quick_start.md
+++ b/docs/modelscope_pipeline/quick_start.md
@@ -1,7 +1,7 @@
# Quick Start
> **Note**:
-> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take typic model as example to demonstrate the usage.
+> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take typic model as example to demonstrate the usage.
## Inference with pipeline
diff --git a/docs/FQA.md b/docs/reference/FQA.md
similarity index 100%
rename from docs/FQA.md
rename to docs/reference/FQA.md
diff --git a/docs/application.md b/docs/reference/application.md
similarity index 100%
rename from docs/application.md
rename to docs/reference/application.md
diff --git a/docs/build_task.md b/docs/reference/build_task.md
similarity index 100%
rename from docs/build_task.md
rename to docs/reference/build_task.md
diff --git a/docs/papers.md b/docs/reference/papers.md
similarity index 100%
rename from docs/papers.md
rename to docs/reference/papers.md
diff --git a/docs/runtime/websocket_cpp.md b/docs/runtime/websocket_cpp.md
new file mode 120000
index 0000000..8a87df5
--- /dev/null
+++ b/docs/runtime/websocket_cpp.md
@@ -0,0 +1 @@
+../../funasr/runtime/websocket/readme.md
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa-asr/README.md
new file mode 100644
index 0000000..882345c
--- /dev/null
+++ b/egs/alimeeting/sa-asr/README.md
@@ -0,0 +1,79 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+To run this receipe, first you need to install FunASR and ModelScope. ([installation](https://alibaba-damo-academy.github.io/FunASR/en/installation.html))
+There are two startup scripts, `run.sh` for training and evaluating on the old eval and test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.
+Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory:
+```shell
+dataset
+|鈥斺�� Eval_Ali_far
+|鈥斺�� Eval_Ali_near
+|鈥斺�� Test_Ali_far
+|鈥斺�� Test_Ali_near
+|鈥斺�� Train_Ali_far
+|鈥斺�� Train_Ali_near
+```
+There are 18 stages in `run.sh`:
+```shell
+stage 1 - 5: Data preparation and processing.
+stage 6: Generate speaker profiles (Stage 6 takes a lot of time).
+stage 7 - 9: Language model training (Optional).
+stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
+stage 12: SA-ASR training.
+stage 13 - 18: Inference and evaluation.
+```
+Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory, which contains only raw audios. Then put the given `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` in the `./data/Test_2023_Ali_far` directory.
+```shell
+data/Test_2023_Ali_far
+|鈥斺�� wav.scp
+|鈥斺�� wav_raw.scp
+|鈥斺�� segments
+|鈥斺�� utt2spk
+|鈥斺�� spk2utt
+```
+There are 4 stages in `run_m2met_2023_infer.sh`:
+```shell
+stage 1: Data preparation and processing.
+stage 2: Generate speaker profiles for inference.
+stage 3: Inference.
+stage 4: Generation of SA-ASR results required for final submission.
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Eval and Test Set are also provided to show the impact of speaker profile accuracy.
+<table>
+ <tr >
+ <td rowspan="2"></td>
+ <td colspan="2">SI-CER(%)</td>
+ <td colspan="2">cpCER(%)</td>
+ </tr>
+ <tr>
+ <td>Eval</td>
+ <td>Test</td>
+ <td>Eval</td>
+ <td>Test</td>
+ </tr>
+ <tr>
+ <td>oracle profile</td>
+ <td>31.93</td>
+ <td>32.75</td>
+ <td>48.56</td>
+ <td>53.33</td>
+ </tr>
+ <tr>
+ <td>cluster profile</td>
+ <td>31.94</td>
+ <td>32.77</td>
+ <td>55.49</td>
+ <td>58.17</td>
+ </tr>
+</table>
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413鈥�4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
new file mode 100755
index 0000000..f8cdcd3
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -0,0 +1,1572 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+ local a b
+ a=$1
+ for b in "$@"; do
+ if [ "${b}" -le "${a}" ]; then
+ a="${b}"
+ fi
+ done
+ echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1 # Processes starts from the specified stage.
+stop_stage=10000 # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false # Skip training stages.
+skip_eval=false # Skip decoding and evaluation stages.
+skip_upload=true # Skip packing and uploading stages.
+ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1 # The number of nodes.
+nj=16 # The number of parallel jobs.
+inference_nj=16 # The number of parallel jobs in decoding.
+gpu_inference=false # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2 # Directory to dump features.
+expdir=exp # Directory to save experiments.
+python=python3 # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw # Feature type (raw or fbank_pitch).
+audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
+fs=16000 # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20 # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe # Tokenization type (char or bpe).
+nbpe=30 # The number of BPE vocabulary.
+bpemode=unigram # Mode of BPE (unigram or bpe).
+oov="<unk>" # Out of vocabulary symbol.
+blank="<blank>" # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbole
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0 # character coverage when modeling BPE
+
+# Language model related
+use_lm=true # Use language model for ASR decoding.
+lm_tag= # Suffix to the result dir for language model training.
+lm_exp= # Specify the direcotry path for LM experiment.
+ # If this option is specified, lm_tag is ignored.
+lm_stats_dir= # Specify the direcotry path for LM statistics.
+lm_config= # Config for language model training.
+lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1 # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag= # Suffix to the result dir for asr model training.
+asr_exp= # Specify the direcotry path for ASR experiment.
+ # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the direcotry path for ASR statistics.
+asr_config= # Config for asr model training.
+sa_asr_config=
+asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalizaton layer type.
+num_splits_asr=1 # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag= # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
+ # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb # Language modle path for decoding.
+inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
+ # e.g.
+ # inference_asr_model=train.loss.best.pth
+ # inference_asr_model=3epoch.pth
+ # inference_asr_model=valid.acc.best.pth
+ # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set= # Name of training set.
+valid_set= # Name of validation set used for monitoring/tuning network training.
+test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text= # Text file path of bpe training set.
+lm_train_text= # Text file path of language model training set.
+lm_dev_text= # Text file path of language model development set.
+lm_test_text= # Text file path of language model evaluation set.
+nlsyms_txt=none # Non-linguistic symbol list if existing.
+cleaner=none # Text cleaner.
+g2p=none # g2p method (needed if token_type=phn).
+lang=zh # The language type of corpus.
+score_opts= # The options given to sclite scoring
+local_score_opts= # The options given to local/score.sh.
+
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+ # General configuration
+ --stage # Processes starts from the specified stage (default="${stage}").
+ --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
+ --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+ --skip_train # Skip training stages (default="${skip_train}").
+ --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
+ --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
+ --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+ --num_nodes # The number of nodes (default="${num_nodes}").
+ --nj # The number of parallel jobs (default="${nj}").
+ --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
+ --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
+ --dumpdir # Directory to dump features (default="${dumpdir}").
+ --expdir # Directory to save experiments (default="${expdir}").
+ --python # Specify python to execute espnet commands (default="${python}").
+ --device # Which GPUs are use for local training (defalut="${device}").
+
+ # Data preparation related
+ --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+ # Speed perturbation related
+ --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+ # Feature extraction related
+ --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+ --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
+ --fs # Sampling rate (default="${fs}").
+ --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+ --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+ # Tokenization related
+ --token_type # Tokenization type (char or bpe, default="${token_type}").
+ --nbpe # The number of BPE vocabulary (default="${nbpe}").
+ --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
+ --oov # Out of vocabulary symbol (default="${oov}").
+ --blank # CTC blank symbol (default="${blank}").
+ --sos_eos # sos and eos symbole (default="${sos_eos}").
+ --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+ --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+ --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+ # Language model related
+ --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
+ --lm_exp # Specify the direcotry path for LM experiment.
+ # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+ --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
+ --lm_config # Config for language model training (default="${lm_config}").
+ --lm_args # Arguments for language model training (default="${lm_args}").
+ # e.g., --lm_args "--max_epoch 10"
+ # Note that it will overwrite args in lm config.
+ --use_word_lm # Whether to use word language model (default="${use_word_lm}").
+ --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+ --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+ # ASR model related
+ --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
+ --asr_exp # Specify the direcotry path for ASR experiment.
+ # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+ --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
+ --asr_config # Config for asr model training (default="${asr_config}").
+ --asr_args # Arguments for asr model training (default="${asr_args}").
+ # e.g., --asr_args "--max_epoch 10"
+ # Note that it will overwrite args in asr config.
+ --feats_normalize # Normalizaton layer type (default="${feats_normalize}").
+ --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
+
+ # Decoding related
+ --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
+ --inference_config # Config for decoding (default="${inference_config}").
+ --inference_args # Arguments for decoding (default="${inference_args}").
+ # e.g., --inference_args "--lm_weight 0.1"
+ # Note that it will overwrite args in inference config.
+ --inference_lm # Language modle path for decoding (default="${inference_lm}").
+ --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+ --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+ # [Task dependent] Set the datadir name created by local/data.sh
+ --train_set # Name of training set (required).
+ --valid_set # Name of validation set used for monitoring/tuning network training (required).
+ --test_sets # Names of test sets.
+ # Multiple items (e.g., both dev and eval sets) can be specified (required).
+ --bpe_train_text # Text file path of bpe training set.
+ --lm_train_text # Text file path of language model training set.
+ --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
+ --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
+ --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+ --cleaner # Text cleaner (default="${cleaner}").
+ --g2p # g2p method (default="${g2p}").
+ --lang # The language type of corpus (default=${lang}).
+ --score_opts # The options given to sclite scoring (default="{score_opts}").
+ --local_score_opts # The options given to local/score.sh (default="{local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+ log "${help_message}"
+ log "Error: No positional arguments are required."
+ exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+ data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+ data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+ data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+ data_feats=${dumpdir}/extracted
+else
+ log "${help_message}"
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+ token_listdir=data/${lang}_token_list
+else
+ token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+ token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+ token_list="${chartoken_list}"
+ bpemodel=none
+elif [ "${token_type}" = word ]; then
+ token_list="${wordtoken_list}"
+ bpemodel=none
+else
+ log "Error: not supported --token_type '${token_type}'"
+ exit 2
+fi
+if ${use_word_lm}; then
+ log "Error: Word LM is not supported yet"
+ exit 2
+
+ lm_token_list="${wordtoken_list}"
+ lm_token_type=word
+else
+ lm_token_list="${token_list}"
+ lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+ if [ -n "${asr_config}" ]; then
+ asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+ else
+ asr_tag="train_${feats_type}"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ asr_tag+="_${lang}_${token_type}"
+ else
+ asr_tag+="_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${asr_args}" ]; then
+ asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_tag+="_sp"
+ fi
+fi
+if [ -z "${lm_tag}" ]; then
+ if [ -n "${lm_config}" ]; then
+ lm_tag="$(basename "${lm_config}" .yaml)"
+ else
+ lm_tag="train"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ lm_tag+="_${lang}_${lm_token_type}"
+ else
+ lm_tag+="_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${lm_args}" ]; then
+ lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+ else
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_stats_dir+="${nbpe}"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_stats_dir+="_sp"
+ fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+ else
+ lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_stats_dir+="${nbpe}"
+ fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+ asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+ lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ inference_tag=inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${inference_args}" ]; then
+ inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ sa_asr_inference_tag=sa_asr_inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${sa_asr_inference_args}" ]; then
+ sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc."
+
+ ./local/alimeeting_data_prep.sh --tgt Test
+ ./local/alimeeting_data_prep.sh --tgt Eval
+ ./local/alimeeting_data_prep.sh --tgt Train
+ fi
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ if [ -n "${speed_perturb_factors}" ]; then
+ log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
+ for factor in ${speed_perturb_factors}; do
+ if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
+ local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
+ _dirs+="data/${train_set}_sp${factor} "
+ else
+ # If speed factor is 1, same as the original
+ _dirs+="data/${train_set} "
+ fi
+ done
+ local/combine_data.sh "data/${train_set}_sp" ${_dirs}
+ else
+ log "Skip stage 2: Speed perturbation"
+ fi
+ fi
+
+ if [ -n "${speed_perturb_factors}" ]; then
+ train_set="${train_set}_sp"
+ fi
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ if [ "${feats_type}" = raw ]; then
+ log "Stage 3: Format wav.scp: data/ -> ${data_feats}"
+
+ # ====== Recreating "wav.scp" ======
+ # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+ # shouldn't be used in training process.
+ # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+ # and it can also change the audio-format and sampling rate.
+ # If nothing is need, then format_wav_scp.sh does nothing:
+ # i.e. the input file format and rate is same as the output.
+
+ for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do
+ if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then
+ _suf="/org"
+ else
+ if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then
+ _suf="/org"
+ else
+ _suf=""
+ fi
+ fi
+ local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+
+ if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+ cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+ fi
+
+ rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+ _opts=
+ if [ -e data/"${dset}"/segments ]; then
+ # "segments" is used for splitting wav files which are written in "wav".scp
+ # into utterances. The file format of segments:
+ # <segment_id> <record_id> <start_time> <end_time>
+ # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+ # Where the time is written in seconds.
+ _opts+="--segments data/${dset}/segments "
+ fi
+ # shellcheck disable=SC2086
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+ echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+ done
+
+ else
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+ fi
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}"
+
+ # NOTE(kamo): Not applying to test_sets to keep original data
+ if [ "${test_sets}" = "Test_Ali_far" ]; then
+ rm_dset="${train_set} ${valid_set} ${test_sets}"
+ else
+ rm_dset="${train_set} ${valid_set}"
+ fi
+
+ for dset in $rm_dset; do
+
+ # Copy data dir
+ local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
+ cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
+
+ # Remove short utterances
+ _feats_type="$(<${data_feats}/${dset}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
+ _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))")
+
+ # utt2num_samples is created by format_wav_scp.sh
+ <"${data_feats}/org/${dset}/utt2num_samples" \
+ awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+ '{ if ($2 > min_length && $2 < max_length ) print $0; }' \
+ >"${data_feats}/${dset}/utt2num_samples"
+ <"${data_feats}/org/${dset}/wav.scp" \
+ utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples" \
+ >"${data_feats}/${dset}/wav.scp"
+ else
+ # Get frame shift in ms from conf/fbank.conf
+ _frame_shift=
+ if [ -f conf/fbank.conf ] && [ "$(<conf/fbank.conf grep -c frame-shift)" -gt 0 ]; then
+ # Assume using conf/fbank.conf for feature extraction
+ _frame_shift="$(<conf/fbank.conf grep frame-shift | sed -e 's/[-a-z =]*\([0-9]*\)/\1/g')"
+ fi
+ if [ -z "${_frame_shift}" ]; then
+ # If not existing, use the default number in Kaldi (=10ms).
+ # If you are using different number, you have to change the following value manually.
+ _frame_shift=10
+ fi
+
+ _min_length=$(python3 -c "print(int(${min_wav_duration} / ${_frame_shift} * 1000))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} / ${_frame_shift} * 1000))")
+
+ cp "${data_feats}/org/${dset}/feats_dim" "${data_feats}/${dset}/feats_dim"
+ <"${data_feats}/org/${dset}/feats_shape" awk -F, ' { print $1 } ' \
+ | awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
+ '{ if ($2 > min_length && $2 < max_length) print $0; }' \
+ >"${data_feats}/${dset}/feats_shape"
+ <"${data_feats}/org/${dset}/feats.scp" \
+ utils/filter_scp.pl "${data_feats}/${dset}/feats_shape" \
+ >"${data_feats}/${dset}/feats.scp"
+ fi
+
+ # Remove empty text
+ <"${data_feats}/org/${dset}/text" \
+ awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
+
+ # fix_data_dir.sh leaves only utts which exist in all files
+ local/fix_data_dir.sh "${data_feats}/${dset}"
+
+ # generate uttid
+ cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
+
+ if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+ # filter utt2spk_all_fifo
+ python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+ fi
+ done
+
+ # shellcheck disable=SC2002
+ cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt"
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ log "Stage 5: Dictionary Preparation"
+ mkdir -p data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ fi
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ log "Stage 6: Generate speaker settings"
+ mkdir -p "profile_log"
+ for dset in "${train_set}" "${valid_set}" "${test_sets}"; do
+ # generate text_id spk2id
+ python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset}
+ log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id"
+ # generate text_id_train for sot
+ python local/process_text_id.py ${data_feats}/${dset}
+ log "Successfully generate ${data_feats}/${dset}/text_id_train"
+ # generate oracle_embedding from single-speaker audio segment
+ log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
+ python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
+ log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
+ # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+ if [ "${dset}" = "${train_set}" ]; then
+ python local/gen_oracle_profile_padding.py ${data_feats}/${dset}
+ log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)"
+ else
+ python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset}
+ log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)"
+ fi
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+ log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
+ python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+ log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+ fi
+
+ done
+ fi
+
+else
+ log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+
+if ! "${skip_train}"; then
+ if "${use_lm}"; then
+ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+ _opts=
+ if [ -n "${lm_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.lm_train --print_config --optim adam
+ _opts+="--config ${lm_config} "
+ fi
+
+ # 1. Split the key file
+ _logdir="${lm_stats_dir}/logdir"
+ mkdir -p "${_logdir}"
+ # Get the minimum number among ${nj} and the number lines of input files
+ _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)")
+
+ key_file="${data_feats}/lm_train.txt"
+ split_scps=""
+ for n in $(seq ${_nj}); do
+ split_scps+=" ${_logdir}/train.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ key_file="${lm_dev_text}"
+ split_scps=""
+ for n in $(seq ${_nj}); do
+ split_scps+=" ${_logdir}/dev.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Generate run.sh
+ log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script"
+ mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh"
+
+ # 3. Submit jobs
+ log "LM collect-stats started... log: '${_logdir}/stats.*.log'"
+ # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+ # but it's used only for deciding the sample ids.
+ # shellcheck disable=SC2086
+ ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+ ${python} -m funasr.bin.lm_train \
+ --collect_stats true \
+ --use_preprocessor true \
+ --bpemodel "${bpemodel}" \
+ --token_type "${lm_token_type}"\
+ --token_list "${lm_token_list}" \
+ --non_linguistic_symbols "${nlsyms_txt}" \
+ --cleaner "${cleaner}" \
+ --g2p "${g2p}" \
+ --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \
+ --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+ --train_shape_file "${_logdir}/train.JOB.scp" \
+ --valid_shape_file "${_logdir}/dev.JOB.scp" \
+ --output_dir "${_logdir}/stats.JOB" \
+ ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+ # 4. Aggregate shape files
+ _opts=
+ for i in $(seq "${_nj}"); do
+ _opts+="--input_dir ${_logdir}/stats.${i} "
+ done
+ # shellcheck disable=SC2086
+ ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}"
+
+ # Append the num-tokens at the last dimensions. This is used for batch-bins count
+ <"${lm_stats_dir}/train/text_shape" \
+ awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+ >"${lm_stats_dir}/train/text_shape.${lm_token_type}"
+
+ <"${lm_stats_dir}/valid/text_shape" \
+ awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
+ >"${lm_stats_dir}/valid/text_shape.${lm_token_type}"
+ fi
+
+
+ if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+ log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
+
+ _opts=
+ if [ -n "${lm_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.lm_train --print_config --optim adam
+ _opts+="--config ${lm_config} "
+ fi
+
+ if [ "${num_splits_lm}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${lm_stats_dir}/splits${num_splits_lm}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \
+ --num_splits "${num_splits_lm}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text "
+ _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} "
+ fi
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+
+ log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script"
+ mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh"
+
+ log "LM training started... log: '${lm_exp}/train.log'"
+ if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # SGE can't include "/" in a job name
+ jobname="$(basename ${lm_exp})"
+ else
+ jobname="${lm_exp}/train.log"
+ fi
+
+ mkdir -p ${lm_exp}
+ mkdir -p ${lm_exp}/log
+ INIT_FILE=${lm_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ lm_train.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
+ --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \
+ --resume true \
+ --output_dir ${lm_exp} \
+ --config $lm_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${lm_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+ fi
+
+
+ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+ log "Stage 9: Calc perplexity: ${lm_test_text}"
+ _opts=
+ # TODO(kamo): Parallelize?
+ log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'"
+ # shellcheck disable=SC2086
+ CUDA_VISIBLE_DEVICES=${device}\
+ ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \
+ ${python} -m funasr.bin.lm_calc_perplexity \
+ --ngpu "${ngpu}" \
+ --data_path_and_name_and_type "${lm_test_text},text,text" \
+ --train_config "${lm_exp}"/config.yaml \
+ --model_file "${lm_exp}/${inference_lm}" \
+ --output_dir "${lm_exp}/perplexity_test" \
+ ${_opts}
+ log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)"
+
+ fi
+
+ else
+ log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}"
+ fi
+
+
+ if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${asr_config} "
+ fi
+
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ # "sound" supports "wav", "flac", etc.
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+ fi
+
+ # 1. Split the key file
+ _logdir="${asr_stats_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ # Get the minimum number among ${nj} and the number lines of input files
+ _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)")
+
+ key_file="${_asr_train_dir}/${_scp}"
+ split_scps=""
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/train.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ key_file="${_asr_valid_dir}/${_scp}"
+ split_scps=""
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/valid.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Generate run.sh
+ log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 9 using this script"
+ mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 9 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh"
+
+ # 3. Submit jobs
+ log "ASR collect-stats started... log: '${_logdir}/stats.*.log'"
+
+ # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
+ # but it's used only for deciding the sample ids.
+
+ # shellcheck disable=SC2086
+ ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
+ ${python} -m funasr.bin.asr_train \
+ --collect_stats true \
+ --mc true \
+ --use_preprocessor true \
+ --bpemodel "${bpemodel}" \
+ --token_type "${token_type}" \
+ --token_list "${token_list}" \
+ --split_with_space false \
+ --non_linguistic_symbols "${nlsyms_txt}" \
+ --cleaner "${cleaner}" \
+ --g2p "${g2p}" \
+ --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \
+ --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+ --train_shape_file "${_logdir}/train.JOB.scp" \
+ --valid_shape_file "${_logdir}/valid.JOB.scp" \
+ --output_dir "${_logdir}/stats.JOB" \
+ ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
+
+ # 4. Aggregate shape files
+ _opts=
+ for i in $(seq "${_nj}"); do
+ _opts+="--input_dir ${_logdir}/stats.${i} "
+ done
+ # shellcheck disable=SC2086
+ ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}"
+
+ # Append the num-tokens at the last dimensions. This is used for batch-bins count
+ <"${asr_stats_dir}/train/text_shape" \
+ awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+ >"${asr_stats_dir}/train/text_shape.${token_type}"
+
+ <"${asr_stats_dir}/valid/text_shape" \
+ awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
+ >"${asr_stats_dir}/valid/text_shape.${token_type}"
+ fi
+
+
+ if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${asr_config} "
+ fi
+
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ # "sound" supports "wav", "flac", etc.
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+
+ fi
+ if [ "${feats_normalize}" = global_mvn ]; then
+ # Default normalization is utterance_mvn and changes to global_mvn
+ _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+ fi
+
+ if [ "${num_splits_asr}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps \
+ "${_asr_train_dir}/${_scp}" \
+ "${_asr_train_dir}/text" \
+ "${asr_stats_dir}/train/speech_shape" \
+ "${asr_stats_dir}/train/text_shape.${token_type}" \
+ --num_splits "${num_splits_asr}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+ _opts+="--train_shape_file ${_split_dir}/speech_shape "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+ fi
+
+ # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+ # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+ log "ASR training started... log: '${asr_exp}/log/train.log'"
+ # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # # SGE can't include "/" in a job name
+ # jobname="$(basename ${asr_exp})"
+ # else
+ # jobname="${asr_exp}/train.log"
+ # fi
+
+ mkdir -p ${asr_exp}
+ mkdir -p ${asr_exp}/log
+ INIT_FILE=${asr_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ asr_train.py \
+ --mc true \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --split_with_space false \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \
+ --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \
+ --valid_shape_file ${asr_stats_dir}/valid/speech_shape \
+ --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \
+ --resume true \
+ --output_dir ${asr_exp} \
+ --config $asr_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+ fi
+
+ if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
+ _asr_train_dir="${data_feats}/${train_set}"
+ _asr_valid_dir="${data_feats}/${valid_set}"
+ log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
+
+ _opts=
+ if [ -n "${sa_asr_config}" ]; then
+ # To generate the config file: e.g.
+ # % python3 -m espnet2.bin.asr_train --print_config --optim adam
+ _opts+="--config ${sa_asr_config} "
+ fi
+
+ _feats_type="$(<${_asr_train_dir}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ # "sound" supports "wav", "flac", etc.
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ _opts+="--frontend_conf fs=${fs} "
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ _input_size="$(<${_asr_train_dir}/feats_dim)"
+ _opts+="--input_size=${_input_size} "
+
+ fi
+ if [ "${feats_normalize}" = global_mvn ]; then
+ # Default normalization is utterance_mvn and changes to global_mvn
+ _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
+ fi
+
+ if [ "${num_splits_asr}" -gt 1 ]; then
+ # If you met a memory error when parsing text files, this option may help you.
+ # The corpus is split into subsets and each subset is used for training one by one in order,
+ # so the memory footprint can be limited to the memory required for each dataset.
+
+ _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
+ if [ ! -f "${_split_dir}/.done" ]; then
+ rm -f "${_split_dir}/.done"
+ ${python} -m espnet2.bin.split_scps \
+ --scps \
+ "${_asr_train_dir}/${_scp}" \
+ "${_asr_train_dir}/text" \
+ "${asr_stats_dir}/train/speech_shape" \
+ "${asr_stats_dir}/train/text_shape.${token_type}" \
+ --num_splits "${num_splits_asr}" \
+ --output_dir "${_split_dir}"
+ touch "${_split_dir}/.done"
+ else
+ log "${_split_dir}/.done exists. Spliting is skipped"
+ fi
+
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int "
+ _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy "
+ _opts+="--train_shape_file ${_split_dir}/speech_shape "
+ _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
+ _opts+="--multiple_iterator true "
+
+ else
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy "
+ _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
+ _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
+ fi
+
+ # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
+ # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
+
+ # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
+ log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'"
+ # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
+ # # SGE can't include "/" in a job name
+ # jobname="$(basename ${asr_exp})"
+ # else
+ # jobname="${asr_exp}/train.log"
+ # fi
+
+ mkdir -p ${sa_asr_exp}
+ mkdir -p ${sa_asr_exp}/log
+ INIT_FILE=${sa_asr_exp}/ddp_init
+
+ if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+ # download xvector extractor model file
+ python local/download_xvector_model.py exp
+ log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+ fi
+
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $device | cut -d',' -f$[$i+1])
+ sa_asr_train.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --unused_parameters true \
+ --bpemodel ${bpemodel} \
+ --token_type ${token_type} \
+ --token_list ${token_list} \
+ --max_spk_num 4 \
+ --split_with_space false \
+ --non_linguistic_symbols ${nlsyms_txt} \
+ --cleaner ${cleaner} \
+ --g2p ${g2p} \
+ --allow_variable_data_keys true \
+ --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \
+ --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \
+ --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \
+ --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \
+ --resume true \
+ --output_dir ${sa_asr_exp} \
+ --config $sa_asr_config \
+ --ngpu $ngpu \
+ --num_worker_count 1 \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $ngpu \
+ --dist_rank $rank \
+ --local_rank $local_rank \
+ ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+ fi
+
+else
+ log "Skip the training stages"
+fi
+
+
+if ! "${skip_eval}"; then
+ if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
+ log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}"
+
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ _cmd="${decode_cmd}"
+ inference_nj=$inference_nj
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 13 using this script"
+ mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${asr_exp}/${inference_tag}/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ echo $_nj
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
+
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+
+ if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
+ log "Stage 14: Scoring multi-talker ASR"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${asr_exp}/${inference_tag}/${dset}"
+
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ done
+
+ fi
+
+ if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
+ log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
+
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ _cmd="${decode_cmd}"
+ inference_nj=$inference_nj
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script"
+ mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
+ log "Stage 16: Scoring SA-ASR (oracle profile)"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
+
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+
+ done
+
+ fi
+
+ if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
+ log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ _cmd="${decode_cmd}"
+ inference_nj=$inference_nj
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
+ mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
+ log "Stage 18: Scoring SA-ASR (cluster profile)"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+
+ done
+
+ fi
+
+else
+ log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
new file mode 100755
index 0000000..a23215c
--- /dev/null
+++ b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
@@ -0,0 +1,591 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+min() {
+ local a b
+ a=$1
+ for b in "$@"; do
+ if [ "${b}" -le "${a}" ]; then
+ a="${b}"
+ fi
+ done
+ echo "${a}"
+}
+SECONDS=0
+
+# General configuration
+stage=1 # Processes starts from the specified stage.
+stop_stage=10000 # Processes is stopped at the specified stage.
+skip_data_prep=false # Skip data preparation stages.
+skip_train=false # Skip training stages.
+skip_eval=false # Skip decoding and evaluation stages.
+skip_upload=true # Skip packing and uploading stages.
+ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
+num_nodes=1 # The number of nodes.
+nj=16 # The number of parallel jobs.
+inference_nj=16 # The number of parallel jobs in decoding.
+gpu_inference=false # Whether to perform gpu decoding.
+njob_infer=4
+dumpdir=dump2 # Directory to dump features.
+expdir=exp # Directory to save experiments.
+python=python3 # Specify python to execute espnet commands.
+device=0
+
+# Data preparation related
+local_data_opts= # The options given to local/data.sh.
+
+# Speed perturbation related
+speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
+
+# Feature extraction related
+feats_type=raw # Feature type (raw or fbank_pitch).
+audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
+fs=16000 # Sampling rate.
+min_wav_duration=0.1 # Minimum duration in second.
+max_wav_duration=20 # Maximum duration in second.
+
+# Tokenization related
+token_type=bpe # Tokenization type (char or bpe).
+nbpe=30 # The number of BPE vocabulary.
+bpemode=unigram # Mode of BPE (unigram or bpe).
+oov="<unk>" # Out of vocabulary symbol.
+blank="<blank>" # CTC blank symbol
+sos_eos="<sos/eos>" # sos and eos symbole
+bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
+bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
+bpe_char_cover=1.0 # character coverage when modeling BPE
+
+# Language model related
+use_lm=true # Use language model for ASR decoding.
+lm_tag= # Suffix to the result dir for language model training.
+lm_exp= # Specify the direcotry path for LM experiment.
+ # If this option is specified, lm_tag is ignored.
+lm_stats_dir= # Specify the direcotry path for LM statistics.
+lm_config= # Config for language model training.
+lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in lm config.
+use_word_lm=false # Whether to use word language model.
+num_splits_lm=1 # Number of splitting for lm corpus.
+# shellcheck disable=SC2034
+word_vocab_size=10000 # Size of word vocabulary.
+
+# ASR model related
+asr_tag= # Suffix to the result dir for asr model training.
+asr_exp= # Specify the direcotry path for ASR experiment.
+ # If this option is specified, asr_tag is ignored.
+sa_asr_exp=
+asr_stats_dir= # Specify the direcotry path for ASR statistics.
+asr_config= # Config for asr model training.
+sa_asr_config=
+asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
+ # Note that it will overwrite args in asr config.
+feats_normalize=global_mvn # Normalizaton layer type.
+num_splits_asr=1 # Number of splitting for lm corpus.
+
+# Decoding related
+inference_tag= # Suffix to the result dir for decoding.
+inference_config= # Config for decoding.
+inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
+ # Note that it will overwrite args in inference config.
+sa_asr_inference_tag=
+sa_asr_inference_args=
+
+inference_lm=valid.loss.ave.pb # Language modle path for decoding.
+inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
+ # e.g.
+ # inference_asr_model=train.loss.best.pth
+ # inference_asr_model=3epoch.pth
+ # inference_asr_model=valid.acc.best.pth
+ # inference_asr_model=valid.loss.ave.pth
+inference_sa_asr_model=valid.acc_spk.ave.pb
+download_model= # Download a model from Model Zoo and use it for decoding.
+
+# [Task dependent] Set the datadir name created by local/data.sh
+train_set= # Name of training set.
+valid_set= # Name of validation set used for monitoring/tuning network training.
+test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
+bpe_train_text= # Text file path of bpe training set.
+lm_train_text= # Text file path of language model training set.
+lm_dev_text= # Text file path of language model development set.
+lm_test_text= # Text file path of language model evaluation set.
+nlsyms_txt=none # Non-linguistic symbol list if existing.
+cleaner=none # Text cleaner.
+g2p=none # g2p method (needed if token_type=phn).
+lang=zh # The language type of corpus.
+score_opts= # The options given to sclite scoring
+local_score_opts= # The options given to local/score.sh.
+
+help_message=$(cat << EOF
+Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
+
+Options:
+ # General configuration
+ --stage # Processes starts from the specified stage (default="${stage}").
+ --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
+ --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
+ --skip_train # Skip training stages (default="${skip_train}").
+ --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
+ --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
+ --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
+ --num_nodes # The number of nodes (default="${num_nodes}").
+ --nj # The number of parallel jobs (default="${nj}").
+ --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
+ --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
+ --dumpdir # Directory to dump features (default="${dumpdir}").
+ --expdir # Directory to save experiments (default="${expdir}").
+ --python # Specify python to execute espnet commands (default="${python}").
+ --device # Which GPUs are use for local training (defalut="${device}").
+
+ # Data preparation related
+ --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
+
+ # Speed perturbation related
+ --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
+
+ # Feature extraction related
+ --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
+ --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
+ --fs # Sampling rate (default="${fs}").
+ --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
+ --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
+
+ # Tokenization related
+ --token_type # Tokenization type (char or bpe, default="${token_type}").
+ --nbpe # The number of BPE vocabulary (default="${nbpe}").
+ --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
+ --oov # Out of vocabulary symbol (default="${oov}").
+ --blank # CTC blank symbol (default="${blank}").
+ --sos_eos # sos and eos symbole (default="${sos_eos}").
+ --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
+ --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
+ --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
+
+ # Language model related
+ --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
+ --lm_exp # Specify the direcotry path for LM experiment.
+ # If this option is specified, lm_tag is ignored (default="${lm_exp}").
+ --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
+ --lm_config # Config for language model training (default="${lm_config}").
+ --lm_args # Arguments for language model training (default="${lm_args}").
+ # e.g., --lm_args "--max_epoch 10"
+ # Note that it will overwrite args in lm config.
+ --use_word_lm # Whether to use word language model (default="${use_word_lm}").
+ --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
+ --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
+
+ # ASR model related
+ --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
+ --asr_exp # Specify the direcotry path for ASR experiment.
+ # If this option is specified, asr_tag is ignored (default="${asr_exp}").
+ --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
+ --asr_config # Config for asr model training (default="${asr_config}").
+ --asr_args # Arguments for asr model training (default="${asr_args}").
+ # e.g., --asr_args "--max_epoch 10"
+ # Note that it will overwrite args in asr config.
+ --feats_normalize # Normalizaton layer type (default="${feats_normalize}").
+ --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
+
+ # Decoding related
+ --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
+ --inference_config # Config for decoding (default="${inference_config}").
+ --inference_args # Arguments for decoding (default="${inference_args}").
+ # e.g., --inference_args "--lm_weight 0.1"
+ # Note that it will overwrite args in inference config.
+ --inference_lm # Language modle path for decoding (default="${inference_lm}").
+ --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
+ --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+
+ # [Task dependent] Set the datadir name created by local/data.sh
+ --train_set # Name of training set (required).
+ --valid_set # Name of validation set used for monitoring/tuning network training (required).
+ --test_sets # Names of test sets.
+ # Multiple items (e.g., both dev and eval sets) can be specified (required).
+ --bpe_train_text # Text file path of bpe training set.
+ --lm_train_text # Text file path of language model training set.
+ --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
+ --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
+ --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
+ --cleaner # Text cleaner (default="${cleaner}").
+ --g2p # g2p method (default="${g2p}").
+ --lang # The language type of corpus (default=${lang}).
+ --score_opts # The options given to sclite scoring (default="{score_opts}").
+ --local_score_opts # The options given to local/score.sh (default="{local_score_opts}").
+EOF
+)
+
+log "$0 $*"
+# Save command line args for logging (they will be lost after utils/parse_options.sh)
+run_args=$(python -m funasr.utils.cli_utils $0 "$@")
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+ log "${help_message}"
+ log "Error: No positional arguments are required."
+ exit 2
+fi
+
+. ./path.sh
+
+
+# Check required arguments
+[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
+[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
+[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
+
+# Check feature type
+if [ "${feats_type}" = raw ]; then
+ data_feats=${dumpdir}/raw
+elif [ "${feats_type}" = fbank_pitch ]; then
+ data_feats=${dumpdir}/fbank_pitch
+elif [ "${feats_type}" = fbank ]; then
+ data_feats=${dumpdir}/fbank
+elif [ "${feats_type}" == extracted ]; then
+ data_feats=${dumpdir}/extracted
+else
+ log "${help_message}"
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+fi
+
+# Use the same text as ASR for bpe training if not specified.
+[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
+# Use the same text as ASR for lm training if not specified.
+[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
+# Use the text of the 1st evaldir if lm_test is not specified
+[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
+
+# Check tokenization type
+if [ "${lang}" != noinfo ]; then
+ token_listdir=data/${lang}_token_list
+else
+ token_listdir=data/token_list
+fi
+bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
+bpeprefix="${bpedir}"/bpe
+bpemodel="${bpeprefix}".model
+bpetoken_list="${bpedir}"/tokens.txt
+chartoken_list="${token_listdir}"/char/tokens.txt
+# NOTE: keep for future development.
+# shellcheck disable=SC2034
+wordtoken_list="${token_listdir}"/word/tokens.txt
+
+if [ "${token_type}" = bpe ]; then
+ token_list="${bpetoken_list}"
+elif [ "${token_type}" = char ]; then
+ token_list="${chartoken_list}"
+ bpemodel=none
+elif [ "${token_type}" = word ]; then
+ token_list="${wordtoken_list}"
+ bpemodel=none
+else
+ log "Error: not supported --token_type '${token_type}'"
+ exit 2
+fi
+if ${use_word_lm}; then
+ log "Error: Word LM is not supported yet"
+ exit 2
+
+ lm_token_list="${wordtoken_list}"
+ lm_token_type=word
+else
+ lm_token_list="${token_list}"
+ lm_token_type="${token_type}"
+fi
+
+
+# Set tag for naming of model directory
+if [ -z "${asr_tag}" ]; then
+ if [ -n "${asr_config}" ]; then
+ asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
+ else
+ asr_tag="train_${feats_type}"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ asr_tag+="_${lang}_${token_type}"
+ else
+ asr_tag+="_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${asr_args}" ]; then
+ asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_tag+="_sp"
+ fi
+fi
+if [ -z "${lm_tag}" ]; then
+ if [ -n "${lm_config}" ]; then
+ lm_tag="$(basename "${lm_config}" .yaml)"
+ else
+ lm_tag="train"
+ fi
+ if [ "${lang}" != noinfo ]; then
+ lm_tag+="_${lang}_${lm_token_type}"
+ else
+ lm_tag+="_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_tag+="${nbpe}"
+ fi
+ # Add overwritten arg's info
+ if [ -n "${lm_args}" ]; then
+ lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
+ fi
+fi
+
+# The directory used for collect-stats mode
+if [ -z "${asr_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
+ else
+ asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
+ fi
+ if [ "${token_type}" = bpe ]; then
+ asr_stats_dir+="${nbpe}"
+ fi
+ if [ -n "${speed_perturb_factors}" ]; then
+ asr_stats_dir+="_sp"
+ fi
+fi
+if [ -z "${lm_stats_dir}" ]; then
+ if [ "${lang}" != noinfo ]; then
+ lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
+ else
+ lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
+ fi
+ if [ "${lm_token_type}" = bpe ]; then
+ lm_stats_dir+="${nbpe}"
+ fi
+fi
+# The directory used for training commands
+if [ -z "${asr_exp}" ]; then
+ asr_exp="${expdir}/asr_${asr_tag}"
+fi
+if [ -z "${lm_exp}" ]; then
+ lm_exp="${expdir}/lm_${lm_tag}"
+fi
+
+
+if [ -z "${inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ inference_tag=inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${inference_args}" ]; then
+ inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+if [ -z "${sa_asr_inference_tag}" ]; then
+ if [ -n "${inference_config}" ]; then
+ sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
+ else
+ sa_asr_inference_tag=sa_asr_inference
+ fi
+ # Add overwritten arg's info
+ if [ -n "${sa_asr_inference_args}" ]; then
+ sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
+ fi
+ if "${use_lm}"; then
+ sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+ fi
+ sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
+fi
+
+train_cmd="run.pl"
+cuda_cmd="run.pl"
+decode_cmd="run.pl"
+
+# ========================== Main stages start from here. ==========================
+
+if ! "${skip_data_prep}"; then
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ if [ "${feats_type}" = raw ]; then
+ log "Stage 1: Format wav.scp: data/ -> ${data_feats}"
+
+ # ====== Recreating "wav.scp" ======
+ # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
+ # shouldn't be used in training process.
+ # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
+ # and it can also change the audio-format and sampling rate.
+ # If nothing is need, then format_wav_scp.sh does nothing:
+ # i.e. the input file format and rate is same as the output.
+
+ for dset in "${test_sets}" ; do
+
+ _suf=""
+
+ local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+
+ rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
+ _opts=
+ if [ -e data/"${dset}"/segments ]; then
+ # "segments" is used for splitting wav files which are written in "wav".scp
+ # into utterances. The file format of segments:
+ # <segment_id> <record_id> <start_time> <end_time>
+ # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
+ # Where the time is written in seconds.
+ _opts+="--segments data/${dset}/segments "
+ fi
+ # shellcheck disable=SC2086
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
+ "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
+
+ echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
+ done
+
+ else
+ log "Error: not supported: --feats_type ${feats_type}"
+ exit 2
+ fi
+ fi
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ log "Stage 2: Generate speaker profile by spectral-cluster"
+ mkdir -p "profile_log"
+ for dset in "${test_sets}"; do
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+ log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
+ done
+ fi
+
+else
+ log "Skip the stages for data preparation"
+fi
+
+
+# ========================== Data preparation is done here. ==========================
+
+if ! "${skip_eval}"; then
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+
+ if ${gpu_inference}; then
+ _cmd="${cuda_cmd}"
+ inference_nj=$[${ngpu}*${njob_infer}]
+ _ngpu=1
+
+ else
+ _cmd="${decode_cmd}"
+ inference_nj=$njob_infer
+ _ngpu=0
+ fi
+
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if "${use_lm}"; then
+ if "${use_word_lm}"; then
+ _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
+ else
+ _opts+="--lm_train_config ${lm_exp}/config.yaml "
+ _opts+="--lm_file ${lm_exp}/${inference_lm} "
+ fi
+ fi
+
+ # 2. Generate run.sh
+ log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
+ mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
+
+ for dset in ${test_sets}; do
+ _data="${data_feats}/${dset}"
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+ _logdir="${_dir}/logdir"
+ mkdir -p "${_logdir}"
+
+ _feats_type="$(<${_data}/feats_type)"
+ if [ "${_feats_type}" = raw ]; then
+ _scp=wav.scp
+ if [[ "${audio_format}" == *ark* ]]; then
+ _type=kaldi_ark
+ else
+ _type=sound
+ fi
+ else
+ _scp=feats.scp
+ _type=kaldi_ark
+ fi
+
+ # 1. Split the key file
+ key_file=${_data}/${_scp}
+ split_scps=""
+ _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+
+ # 2. Submit decoding jobs
+ log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
+ # shellcheck disable=SC2086
+ ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --nbest 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob_infer} \
+ --gpuid_list ${device} \
+ --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --allow_variable_data_keys true \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ # 3. Concatenates the output files from each jobs
+ for f in token token_int score text text_id; do
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | LC_ALL=C sort -k1 >"${_dir}/${f}"
+ done
+ done
+ fi
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "Stage 4: Generate SA-ASR results (cluster profile)"
+
+ for dset in ${test_sets}; do
+ _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
+
+ python local/process_text_spk_merge.py ${_dir}
+ done
+
+ fi
+
+else
+ log "Skip the evaluation stages"
+fi
+
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
new file mode 100644
index 0000000..88fdbc2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.6
+lm_weight: 0.3
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
new file mode 100644
index 0000000..7865763
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
@@ -0,0 +1,87 @@
+# network architecture
+frontend: default
+frontend_conf:
+ n_fft: 400
+ win_length: 400
+ hop_length: 160
+
+# encoder related
+encoder: conformer
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ rel_pos_type: latest
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+ ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000 # reduce/increase this number according to your GPU memory
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
new file mode 100644
index 0000000..68520ae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
@@ -0,0 +1,29 @@
+lm: transformer
+lm_conf:
+ pos_enc: null
+ embed_unit: 128
+ att_unit: 512
+ head: 8
+ unit: 2048
+ layer: 16
+ dropout_rate: 0.1
+
+# optimization related
+grad_clip: 5.0
+batch_type: numel
+batch_bins: 500000 # 4gpus * 500000
+accum_grad: 1
+max_epoch: 15 # 15epoch is enougth
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+best_model_criterion:
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
new file mode 100644
index 0000000..421d7df
--- /dev/null
+++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -0,0 +1,115 @@
+# network architecture
+frontend: default
+frontend_conf:
+ n_fft: 400
+ win_length: 400
+ hop_length: 160
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+ use_head_conv: true
+ batchnorm_momentum: 0.5
+ use_head_maxpool: false
+ num_nodes_pooling_layer: 256
+ layers_in_block:
+ - 3
+ - 4
+ - 6
+ - 3
+ filters_in_block:
+ - 32
+ - 64
+ - 128
+ - 256
+ pooling_type: statistic
+ num_nodes_resnet1: 256
+ num_nodes_last_layer: 256
+ batchnorm_momentum: 0.5
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ asr_num_blocks: 6
+ spk_num_blocks: 3
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+ spk_weight: 0.5
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+ctc_conf:
+ ignore_nan_grad: true
+
+# minibatch related
+batch_type: numel
+batch_bins: 10000000
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - acc
+ - max
+- - valid
+ - acc_spk
+ - max
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
new file mode 100755
index 0000000..7d39cdc
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -0,0 +1,162 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_messge=$(cat << EOF
+Usage: $0
+
+Options:
+ --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+ --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+ log "${help_message}"
+ exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+ log "Error: ${AliMeeting} is empty."
+ exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+near_raw_dir=${AliMeeting}/${tgt}_Ali_near/
+
+far_dir=data/local/${tgt}_Ali_far
+near_dir=data/local/${tgt}_Ali_near
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=4
+mkdir -p $far_dir
+mkdir -p $near_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ log "stage 1:process alimeeting near dir"
+
+ find -L $near_raw_dir/audio_dir -iname "*.wav" > $near_dir/wavlist
+ awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
+ find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
+ n1_wav=$(wc -l < $near_dir/wavlist)
+ n2_text=$(wc -l < $near_dir/textgrid.flist)
+ log near file found $n1_wav wav and $n2_text text.
+
+ paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
+
+ # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+ cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+
+ python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
+ cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
+ utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
+ #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+ local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
+ utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
+ sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
+ sed -e 's/锛�//g' $near_dir/tmp1> $near_dir/tmp2
+ sed -e 's/锛�//g' $near_dir/tmp2> $near_dir/text
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ log "stage 2:process alimeeting far dir"
+
+ find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
+ awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
+ find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
+ n1_wav=$(wc -l < $far_dir/wavlist)
+ n2_text=$(wc -l < $far_dir/textgrid.flist)
+ log far file found $n1_wav wav and $n2_text text.
+
+ paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+ cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+
+ python local/alimeeting_process_overlap_force.py --path $far_dir \
+ --no-overlap false --mars True \
+ --overlap_length 0.8 --max_length 7
+
+ cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
+
+ local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+ sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+ sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+ sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+ sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ log "stage 3: finali data process"
+
+ local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
+ local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+ sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+ # remove space in text
+ for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+ cp -r $far_dir/* $far_single_speaker_dir
+ mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
+ paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+ python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
+
+ cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
+ local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+ ./local/fix_data_dir.sh $far_single_speaker_dir
+ local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+ # remove space in text
+ for x in ${tgt}_Ali_far_single_speaker; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
new file mode 100755
index 0000000..e3ce934
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
@@ -0,0 +1,129 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+help_messge=$(cat << EOF
+Usage: $0
+
+Options:
+ --no_overlap (bool): Whether to ignore the overlapping utterance in the training set.
+ --tgt (string): Which set to process, test or train.
+EOF
+)
+
+SECONDS=0
+tgt=Train #Train or Eval
+
+
+log "$0 $*"
+echo $tgt
+. ./utils/parse_options.sh
+
+. ./path.sh
+
+AliMeeting="${PWD}/dataset"
+
+if [ $# -gt 2 ]; then
+ log "${help_message}"
+ exit 2
+fi
+
+
+if [ ! -d "${AliMeeting}" ]; then
+ log "Error: ${AliMeeting} is empty."
+ exit 2
+fi
+
+# To absolute path
+AliMeeting=$(cd ${AliMeeting}; pwd)
+echo $AliMeeting
+far_raw_dir=${AliMeeting}/${tgt}_Ali_far/
+
+far_dir=data/local/${tgt}_Ali_far
+far_single_speaker_dir=data/local/${tgt}_Ali_far_correct_single_speaker
+mkdir -p $far_single_speaker_dir
+
+stage=1
+stop_stage=3
+mkdir -p $far_dir
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ log "stage 1:process alimeeting far dir"
+
+ find -L $far_raw_dir/audio_dir -iname "*.wav" > $far_dir/wavlist
+ awk -F '/' '{print $NF}' $far_dir/wavlist | awk -F '.' '{print $1}' > $far_dir/uttid
+ find -L $far_raw_dir/textgrid_dir -iname "*.TextGrid" > $far_dir/textgrid.flist
+ n1_wav=$(wc -l < $far_dir/wavlist)
+ n2_text=$(wc -l < $far_dir/textgrid.flist)
+ log far file found $n1_wav wav and $n2_text text.
+
+ paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
+
+ cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+
+ python local/alimeeting_process_overlap_force.py --path $far_dir \
+ --no-overlap false --mars True \
+ --overlap_length 0.8 --max_length 7
+
+ cat $far_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $far_dir/text
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
+ #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
+
+ local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
+ sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
+ sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
+ sed -e 's/锛�//g' $far_dir/tmp2> $far_dir/tmp3
+ sed -e 's/锛�//g' $far_dir/tmp3> $far_dir/text
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ log "stage 2: finali data process"
+
+ local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+
+ sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+
+ # remove space in text
+ for x in ${tgt}_Ali_far; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ log "stage 3:process alimeeting far dir (single speaker by oracal time strap)"
+ cp -r $far_dir/* $far_single_speaker_dir
+ mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
+ paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
+ python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
+
+ cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
+ local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+
+ ./local/fix_data_dir.sh $far_single_speaker_dir
+ local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+
+ # remove space in text
+ for x in ${tgt}_Ali_far_single_speaker; do
+ cp data/${x}/text data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
+ > data/${x}/text
+ rm data/${x}/text.org
+ done
+ log "Successfully finished. [elapsed=${SECONDS}s]"
+fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
new file mode 100755
index 0000000..8ece757
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+ def __init__(self, uttid, spkr, stime, etime, text):
+ self.uttid = uttid
+ self.spkr = spkr
+ self.spkr_all = uttid+"-"+spkr
+ self.stime = round(stime, 2)
+ self.etime = round(etime, 2)
+ self.text = text
+ self.spk_text = {uttid+"-"+spkr: text}
+
+ def change_stime(self, time):
+ self.stime = time
+
+ def change_etime(self, time):
+ self.etime = time
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ parser.add_argument(
+ "--no-overlap",
+ type=strtobool,
+ default=False,
+ help="Whether to ignore the overlapping utterances.",
+ )
+ parser.add_argument(
+ "--max_length",
+ default=100000,
+ type=float,
+ help="overlap speech max time,if longger than max length should cut",
+ )
+ parser.add_argument(
+ "--overlap_length",
+ default=1,
+ type=float,
+ help="if length longer than max length, speech overlength shorter, is cut",
+ )
+ parser.add_argument(
+ "--mars",
+ type=strtobool,
+ default=False,
+ help="Whether to process mars data set.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def preposs_overlap(segments,max_length,overlap_length):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = segments[0]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+ overlap_length_big = 1.5
+ max_length_big = 15
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+ # doesn't overlap with preivous segments
+ new_segments.append(tmp_segments)
+ tmp_segments = segments[i]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ dur_time = max_etime - min_stime
+ if dur_time < max_length:
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+ tmp_segments.stime = min_stime
+ tmp_segments.etime = max_etime
+ tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+ spk_name =segments[i].uttid +"-" + segments[i].spkr
+ if spk_name in tmp_segments.spk_text:
+ tmp_segments.spk_text[spk_name] += segments[i].text
+ else:
+ tmp_segments.spk_text[spk_name] = segments[i].text
+ tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+ else:
+ overlap_time = max_etime - segments[i].stime
+ if dur_time < max_length_big:
+ overlap_length_option = overlap_length
+ else:
+ overlap_length_option = overlap_length_big
+ if overlap_time > overlap_length_option:
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+ tmp_segments.stime = min_stime
+ tmp_segments.etime = max_etime
+ tmp_segments.text = tmp_segments.text + "src" + segments[i].text
+ spk_name =segments[i].uttid +"-" + segments[i].spkr
+ if spk_name in tmp_segments.spk_text:
+ tmp_segments.spk_text[spk_name] += segments[i].text
+ else:
+ tmp_segments.spk_text[spk_name] = segments[i].text
+ tmp_segments.spkr_all = tmp_segments.spkr_all + "src" + spk_name
+ else:
+ new_segments.append(tmp_segments)
+ tmp_segments = segments[i]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+
+ return new_segments
+
+def filter_overlap(segments):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = [segments[0]]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+ # doesn't overlap with preivous segments
+ if len(tmp_segments) == 1:
+ new_segments.append(tmp_segments[0])
+ # TODO: for multi-spkr asr, we can reset the stime/etime to
+ # min_stime/max_etime for generating a max length mixutre speech
+ tmp_segments = [segments[i]]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ tmp_segments.append(segments[i])
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+
+ return new_segments
+
+
+def main(args):
+ wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+ textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+ # get the path of textgrid file for each utterance
+ utt2textgrid = {}
+ for line in textgrid_flist:
+ path = Path(line.strip())
+ uttid = path.stem
+ utt2textgrid[uttid] = path
+
+ # parse the textgrid file for each utterance
+ all_segments = []
+ for line in wav_scp:
+ uttid = line.strip().split(" ")[0]
+ uttid_part=uttid
+ if args.mars == True:
+ uttid_list = uttid.split("_")
+ uttid_part= uttid_list[0]+"_"+uttid_list[1]
+ if uttid_part not in utt2textgrid:
+ print("%s doesn't have transcription" % uttid)
+ continue
+
+ segments = []
+ tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+ for i in range(tg.__len__()):
+ for j in range(tg[i].__len__()):
+ if tg[i][j].mark:
+ segments.append(
+ Segment(
+ uttid,
+ tg[i].name,
+ tg[i][j].minTime,
+ tg[i][j].maxTime,
+ tg[i][j].mark.strip(),
+ )
+ )
+
+ segments = sorted(segments, key=lambda x: x.stime)
+
+ if args.no_overlap:
+ segments = filter_overlap(segments)
+ else:
+ segments = preposs_overlap(segments,args.max_length,args.overlap_length)
+ all_segments += segments
+
+ wav_scp.close()
+ textgrid_flist.close()
+
+ segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+ utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+ text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+ utt2spk_file_fifo = codecs.open(Path(args.path) / "utt2spk_all_fifo", "w", "utf-8")
+
+ for i in range(len(all_segments)):
+ utt_name = "%s-%s-%07d-%07d" % (
+ all_segments[i].uttid,
+ all_segments[i].spkr,
+ all_segments[i].stime * 100,
+ all_segments[i].etime * 100,
+ )
+
+ segments_file.write(
+ "%s %s %.2f %.2f\n"
+ % (
+ utt_name,
+ all_segments[i].uttid,
+ all_segments[i].stime,
+ all_segments[i].etime,
+ )
+ )
+ utt2spk_file.write(
+ "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+ )
+ utt2spk_file_fifo.write(
+ "%s %s\n" % (utt_name, all_segments[i].spkr_all)
+ )
+ text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+ segments_file.close()
+ utt2spk_file.close()
+ text_file.close()
+ utt2spk_file_fifo.close()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
new file mode 100755
index 0000000..81c1965
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+class Segment(object):
+ def __init__(self, uttid, spkr, stime, etime, text):
+ self.uttid = uttid
+ self.spkr = spkr
+ self.stime = round(stime, 2)
+ self.etime = round(etime, 2)
+ self.text = text
+
+ def change_stime(self, time):
+ self.stime = time
+
+ def change_etime(self, time):
+ self.etime = time
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ parser.add_argument(
+ "--no-overlap",
+ type=strtobool,
+ default=False,
+ help="Whether to ignore the overlapping utterances.",
+ )
+ parser.add_argument(
+ "--mars",
+ type=strtobool,
+ default=False,
+ help="Whether to process mars data set.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def filter_overlap(segments):
+ new_segments = []
+ # init a helper list to store all overlap segments
+ tmp_segments = [segments[0]]
+ min_stime = segments[0].stime
+ max_etime = segments[0].etime
+
+ for i in range(1, len(segments)):
+ if segments[i].stime >= max_etime:
+ # doesn't overlap with preivous segments
+ if len(tmp_segments) == 1:
+ new_segments.append(tmp_segments[0])
+ # TODO: for multi-spkr asr, we can reset the stime/etime to
+ # min_stime/max_etime for generating a max length mixutre speech
+ tmp_segments = [segments[i]]
+ min_stime = segments[i].stime
+ max_etime = segments[i].etime
+ else:
+ # overlap with previous segments
+ tmp_segments.append(segments[i])
+ if min_stime > segments[i].stime:
+ min_stime = segments[i].stime
+ if max_etime < segments[i].etime:
+ max_etime = segments[i].etime
+
+ return new_segments
+
+
+def main(args):
+ wav_scp = codecs.open(Path(args.path) / "wav.scp", "r", "utf-8")
+ textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+
+ # get the path of textgrid file for each utterance
+ utt2textgrid = {}
+ for line in textgrid_flist:
+ path = Path(line.strip())
+ uttid = path.stem
+ utt2textgrid[uttid] = path
+
+ # parse the textgrid file for each utterance
+ all_segments = []
+ for line in wav_scp:
+ uttid = line.strip().split(" ")[0]
+ uttid_part=uttid
+ if args.mars == True:
+ uttid_list = uttid.split("_")
+ uttid_part= uttid_list[0]+"_"+uttid_list[1]
+ if uttid_part not in utt2textgrid:
+ print("%s doesn't have transcription" % uttid)
+ continue
+ #pdb.set_trace()
+ segments = []
+ try:
+ tg = textgrid.TextGrid.fromFile(utt2textgrid[uttid_part])
+ except:
+ pdb.set_trace()
+ for i in range(tg.__len__()):
+ for j in range(tg[i].__len__()):
+ if tg[i][j].mark:
+ segments.append(
+ Segment(
+ uttid,
+ tg[i].name,
+ tg[i][j].minTime,
+ tg[i][j].maxTime,
+ tg[i][j].mark.strip(),
+ )
+ )
+
+ segments = sorted(segments, key=lambda x: x.stime)
+
+ if args.no_overlap:
+ segments = filter_overlap(segments)
+
+ all_segments += segments
+
+ wav_scp.close()
+ textgrid_flist.close()
+
+ segments_file = codecs.open(Path(args.path) / "segments_all", "w", "utf-8")
+ utt2spk_file = codecs.open(Path(args.path) / "utt2spk_all", "w", "utf-8")
+ text_file = codecs.open(Path(args.path) / "text_all", "w", "utf-8")
+
+ for i in range(len(all_segments)):
+ utt_name = "%s-%s-%07d-%07d" % (
+ all_segments[i].uttid,
+ all_segments[i].spkr,
+ all_segments[i].stime * 100,
+ all_segments[i].etime * 100,
+ )
+
+ segments_file.write(
+ "%s %s %.2f %.2f\n"
+ % (
+ utt_name,
+ all_segments[i].uttid,
+ all_segments[i].stime,
+ all_segments[i].etime,
+ )
+ )
+ utt2spk_file.write(
+ "%s %s-%s\n" % (utt_name, all_segments[i].uttid, all_segments[i].spkr)
+ )
+ text_file.write("%s %s\n" % (utt_name, all_segments[i].text))
+
+ segments_file.close()
+ utt2spk_file.close()
+ text_file.close()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/apply_map.pl b/egs/alimeeting/sa-asr/local/apply_map.pl
new file mode 100755
index 0000000..725d346
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/apply_map.pl
@@ -0,0 +1,97 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
+# Apache 2.0.
+
+# This program is a bit like ./sym2int.pl in that it applies a map
+# to things in a file, but it's a bit more general in that it doesn't
+# assume the things being mapped to are single tokens, they could
+# be sequences of tokens. See the usage message.
+
+
+$permissive = 0;
+
+for ($x = 0; $x <= 2; $x++) {
+
+ if (@ARGV > 0 && $ARGV[0] eq "-f") {
+ shift @ARGV;
+ $field_spec = shift @ARGV;
+ if ($field_spec =~ m/^\d+$/) {
+ $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
+ }
+ if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
+ if ($1 ne "") {
+ $field_begin = $1 - 1; # Change to zero-based indexing.
+ }
+ if ($2 ne "") {
+ $field_end = $2 - 1; # Change to zero-based indexing.
+ }
+ }
+ if (!defined $field_begin && !defined $field_end) {
+ die "Bad argument to -f option: $field_spec";
+ }
+ }
+
+ if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
+ shift @ARGV;
+ # Mapping is optional (missing key is printed to output)
+ $permissive = 1;
+ }
+}
+
+if(@ARGV != 1) {
+ print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
+ print STDERR <<'EOF';
+Usage: apply_map.pl [options] map <input >output
+ options: [-f <field-range> ] [--permissive]
+ This applies a map to some specified fields of some input text:
+ For each line in the map file: the first field is the thing we
+ map from, and the remaining fields are the sequence we map it to.
+ The -f (field-range) option says which fields of the input file the map
+ map should apply to.
+ If the --permissive option is supplied, fields which are not present
+ in the map will be left as they were.
+ Applies the map 'map' to all input text, where each line of the map
+ is interpreted as a map from the first field to the list of the other fields
+ Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
+ range in the input to apply the map to.
+ e.g.: echo A B | apply_map.pl a.txt
+ where a.txt is:
+ A a1 a2
+ B b
+ will produce:
+ a1 a2 b
+EOF
+ exit(1);
+}
+
+($map_file) = @ARGV;
+open(M, "<$map_file") || die "Error opening map file $map_file: $!";
+
+while (<M>) {
+ @A = split(" ", $_);
+ @A >= 1 || die "apply_map.pl: empty line.";
+ $i = shift @A;
+ $o = join(" ", @A);
+ $map{$i} = $o;
+}
+
+while(<STDIN>) {
+ @A = split(" ", $_);
+ for ($x = 0; $x < @A; $x++) {
+ if ( (!defined $field_begin || $x >= $field_begin)
+ && (!defined $field_end || $x <= $field_end)) {
+ $a = $A[$x];
+ if (!defined $map{$a}) {
+ if (!$permissive) {
+ die "apply_map.pl: undefined key $a in $map_file\n";
+ } else {
+ print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
+ }
+ } else {
+ $A[$x] = $map{$a};
+ }
+ }
+ }
+ print join(" ", @A) . "\n";
+}
diff --git a/egs/alimeeting/sa-asr/local/combine_data.sh b/egs/alimeeting/sa-asr/local/combine_data.sh
new file mode 100755
index 0000000..a3436b5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/combine_data.sh
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
+# 2014 David Snyder
+
+# This script combines the data from multiple source directories into
+# a single destination directory.
+
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
+# about what these directories contain.
+
+# Begin configuration section.
+extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
+skip_fix=false # skip the fix_data_dir.sh in the end
+# End configuration section.
+
+echo "$0 $@" # Print the command line for logging
+
+if [ -f path.sh ]; then . ./path.sh; fi
+. parse_options.sh || exit 1;
+
+if [ $# -lt 2 ]; then
+ echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
+ echo "Note, files that don't appear in all source dirs will not be combined,"
+ echo "with the exception of utt2uniq and segments, which are created where necessary."
+ exit 1
+fi
+
+dest=$1;
+shift;
+
+first_src=$1;
+
+rm -r $dest 2>/dev/null || true
+mkdir -p $dest;
+
+export LC_ALL=C
+
+for dir in $*; do
+ if [ ! -f $dir/utt2spk ]; then
+ echo "$0: no such file $dir/utt2spk"
+ exit 1;
+ fi
+done
+
+# Check that frame_shift are compatible, where present together with features.
+dir_with_frame_shift=
+for dir in $*; do
+ if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
+ if [[ $dir_with_frame_shift ]] &&
+ ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
+ echo "$0:error: different frame_shift in directories $dir and " \
+ "$dir_with_frame_shift. Cannot combine features."
+ exit 1;
+ fi
+ dir_with_frame_shift=$dir
+ fi
+done
+
+# W.r.t. utt2uniq file the script has different behavior compared to other files
+# it is not compulsary for it to exist in src directories, but if it exists in
+# even one it should exist in all. We will create the files where necessary
+has_utt2uniq=false
+for in_dir in $*; do
+ if [ -f $in_dir/utt2uniq ]; then
+ has_utt2uniq=true
+ break
+ fi
+done
+
+if $has_utt2uniq; then
+ # we are going to create an utt2uniq file in the destdir
+ for in_dir in $*; do
+ if [ ! -f $in_dir/utt2uniq ]; then
+ # we assume that utt2uniq is a one to one mapping
+ cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
+ else
+ cat $in_dir/utt2uniq
+ fi
+ done | sort -k1 > $dest/utt2uniq
+ echo "$0: combined utt2uniq"
+else
+ echo "$0 [info]: not combining utt2uniq as it does not exist"
+fi
+# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
+extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
+
+# segments are treated similarly to utt2uniq. If it exists in some, but not all
+# src directories, then we generate segments where necessary.
+has_segments=false
+for in_dir in $*; do
+ if [ -f $in_dir/segments ]; then
+ has_segments=true
+ break
+ fi
+done
+
+if $has_segments; then
+ for in_dir in $*; do
+ if [ ! -f $in_dir/segments ]; then
+ echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
+ local/data/get_segments_for_data.sh $in_dir
+ else
+ cat $in_dir/segments
+ fi
+ done | sort -k1 > $dest/segments
+ echo "$0: combined segments"
+else
+ echo "$0 [info]: not combining segments as it does not exist"
+fi
+
+for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
+ exists_somewhere=false
+ absent_somewhere=false
+ for d in $*; do
+ if [ -f $d/$file ]; then
+ exists_somewhere=true
+ else
+ absent_somewhere=true
+ fi
+ done
+
+ if ! $absent_somewhere; then
+ set -o pipefail
+ ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
+ set +o pipefail
+ echo "$0: combined $file"
+ else
+ if ! $exists_somewhere; then
+ echo "$0 [info]: not combining $file as it does not exist"
+ else
+ echo "$0 [info]: **not combining $file as it does not exist everywhere**"
+ fi
+ fi
+done
+
+local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+
+if [[ $dir_with_frame_shift ]]; then
+ cp $dir_with_frame_shift/frame_shift $dest
+fi
+
+if ! $skip_fix ; then
+ local/fix_data_dir.sh $dest || exit 1;
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa-asr/local/compute_cpcer.py
new file mode 100644
index 0000000..f4d4a79
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/compute_cpcer.py
@@ -0,0 +1,91 @@
+import editdistance
+import sys
+import os
+from itertools import permutations
+
+
+def load_transcripts(file_path):
+ trans_list = []
+ for one_line in open(file_path, "rt"):
+ meeting_id, trans = one_line.strip().split(" ")
+ trans_list.append((meeting_id.strip(), trans.strip()))
+
+ return trans_list
+
+def calc_spk_trans(trans):
+ spk_trans_ = [x.strip() for x in trans.split("$")]
+ spk_trans = []
+ for i in range(len(spk_trans_)):
+ spk_trans.append((str(i), spk_trans_[i]))
+ return spk_trans
+
+def calc_cer(ref_trans, hyp_trans):
+ ref_spk_trans = calc_spk_trans(ref_trans)
+ hyp_spk_trans = calc_spk_trans(hyp_trans)
+ ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
+ num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
+ ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
+ hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
+
+ errors, counts, permutes = [], [], []
+ min_error = 0
+ cost_dict = {}
+ for perm in permutations(range(num_spk)):
+ flag = True
+ p_err, p_count = 0, 0
+ for idx, p in enumerate(perm):
+ if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
+ flag = False
+ break
+ cost_key = "{}-{}".format(idx, p)
+ if cost_key in cost_dict:
+ _e = cost_dict[cost_key]
+ else:
+ _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
+ cost_dict[cost_key] = _e
+ if _e > min_error > 0:
+ flag = False
+ break
+ p_err += _e
+ p_count += len(ref_spk_trans[idx][1])
+
+ if flag:
+ if p_err < min_error or min_error == 0:
+ min_error = p_err
+
+ errors.append(p_err)
+ counts.append(p_count)
+ permutes.append(perm)
+
+ sd_cer = [(err, cnt, err/cnt, permute)
+ for err, cnt, permute in zip(errors, counts, permutes)]
+ # import ipdb;ipdb.set_trace()
+ best_rst = min(sd_cer, key=lambda x: x[2])
+
+ return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
+
+
+def main():
+ ref=sys.argv[1]
+ hyp=sys.argv[2]
+ result_path=sys.argv[3]
+ ref_list = load_transcripts(ref)
+ hyp_list = load_transcripts(hyp)
+ result_file = open(result_path,'w')
+ error, count = 0, 0
+ for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
+ assert ref_id == hyp_id
+ mid = ref_id
+ dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
+ error, count = error + dist, count + length
+ result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+ # print("{} {:.2f} {} {}".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
+
+ result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
+ result_file.close()
+ # print("Sum/Avg: {:.2f}".format(error / count * 100.0))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/alimeeting/sa-asr/local/copy_data_dir.sh b/egs/alimeeting/sa-asr/local/copy_data_dir.sh
new file mode 100755
index 0000000..6e748dd
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/copy_data_dir.sh
@@ -0,0 +1,145 @@
+#!/usr/bin/env bash
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+# feats.scp
+# wav.scp
+# vad.scp
+# spk2utt
+# utt2spk
+# text
+#
+# It copies to another directory, possibly adding a specified prefix or a suffix
+# to the utterance and/or speaker names. Note, the recording-ids stay the same.
+#
+
+
+# begin configuration section
+spk_prefix=
+utt_prefix=
+spk_suffix=
+utt_suffix=
+validate_opts= # should rarely be needed.
+# end configuration section
+
+. utils/parse_options.sh
+
+if [ $# != 2 ]; then
+ echo "Usage: "
+ echo " $0 [options] <srcdir> <destdir>"
+ echo "e.g.:"
+ echo " $0 --spk-prefix=1- --utt-prefix=1- data/train data/train_1"
+ echo "Options"
+ echo " --spk-prefix=<prefix> # Prefix for speaker ids, default empty"
+ echo " --utt-prefix=<prefix> # Prefix for utterance ids, default empty"
+ echo " --spk-suffix=<suffix> # Suffix for speaker ids, default empty"
+ echo " --utt-suffix=<suffix> # Suffix for utterance ids, default empty"
+ exit 1;
+fi
+
+
+export LC_ALL=C
+
+srcdir=$1
+destdir=$2
+
+if [ ! -f $srcdir/utt2spk ]; then
+ echo "copy_data_dir.sh: no such file $srcdir/utt2spk"
+ exit 1;
+fi
+
+if [ "$destdir" == "$srcdir" ]; then
+ echo "$0: this script requires <srcdir> and <destdir> to be different."
+ exit 1
+fi
+
+set -e;
+
+mkdir -p $destdir
+
+cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/utt_map
+cat $srcdir/spk2utt | awk -v p=$spk_prefix -v s=$spk_suffix '{printf("%s %s%s%s\n", $1, p, $1, s);}' > $destdir/spk_map
+
+if [ ! -f $srcdir/utt2uniq ]; then
+ if [[ ! -z $utt_prefix || ! -z $utt_suffix ]]; then
+ cat $srcdir/utt2spk | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $1);}' > $destdir/utt2uniq
+ fi
+else
+ cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
+fi
+
+cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \
+ local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
+
+local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
+
+if [ -f $srcdir/feats.scp ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
+fi
+
+if [ -f $srcdir/vad.scp ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
+fi
+
+if [ -f $srcdir/segments ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
+ cp $srcdir/wav.scp $destdir
+else # no segments->wav indexed by utt.
+ if [ -f $srcdir/wav.scp ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
+ fi
+fi
+
+if [ -f $srcdir/reco2file_and_channel ]; then
+ cp $srcdir/reco2file_and_channel $destdir/
+fi
+
+if [ -f $srcdir/text ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
+fi
+if [ -f $srcdir/utt2dur ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
+fi
+if [ -f $srcdir/utt2num_frames ]; then
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
+fi
+if [ -f $srcdir/reco2dur ]; then
+ if [ -f $srcdir/segments ]; then
+ cp $srcdir/reco2dur $destdir/reco2dur
+ else
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
+ fi
+fi
+if [ -f $srcdir/spk2gender ]; then
+ local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
+fi
+if [ -f $srcdir/cmvn.scp ]; then
+ local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
+fi
+for f in frame_shift stm glm ctm; do
+ if [ -f $srcdir/$f ]; then
+ cp $srcdir/$f $destdir
+ fi
+done
+
+rm $destdir/spk_map $destdir/utt_map
+
+echo "$0: copied data from $srcdir to $destdir"
+
+for f in feats.scp cmvn.scp vad.scp utt2lang utt2uniq utt2dur utt2num_frames text wav.scp reco2file_and_channel frame_shift stm glm ctm; do
+ if [ -f $destdir/$f ] && [ ! -f $srcdir/$f ]; then
+ echo "$0: file $f exists in dest $destdir but not in src $srcdir. Moving it to"
+ echo " ... $destdir/.backup/$f"
+ mkdir -p $destdir/.backup
+ mv $destdir/$f $destdir/.backup/
+ fi
+done
+
+
+[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
+[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
+
+local/validate_data_dir.sh $validate_opts $destdir
diff --git a/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
new file mode 100755
index 0000000..24f51e7
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
@@ -0,0 +1,143 @@
+#!/usr/bin/env bash
+
+# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+# 2018 Andrea Carmantini
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# reco2dur file if it does not already exist. The file 'reco2dur' maps from
+# recording to the duration of the recording in seconds. This script works it
+# out from the 'wav.scp' file, or, if utterance-ids are the same as recording-ids, from the
+# utt2dur file (it first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+# We could use durations from segments file, but that's not the duration of the recordings
+# but the sum of utterance lenghts (silence in between could be excluded from segments)
+# For sum of utterance lenghts:
+# awk 'FNR==NR{uttdur[$1]=$2;next}
+# { for(i=2;i<=NF;i++){dur+=uttdur[$i];}
+# print $1 FS dur; dur=0 }' $data/utt2dur $data/reco2utt
+
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train"
+ echo " Options:"
+ echo " --frame-shift # frame shift in seconds. Only relevant when we are"
+ echo " # getting duration from feats.scp (default: 0.01). "
+ exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+
+if [ -s $data/reco2dur ] && \
+ [ $(wc -l < $data/wav.scp) -eq $(wc -l < $data/reco2dur) ]; then
+ echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it."
+ exit 0;
+fi
+
+if [ -s $data/utt2dur ] && \
+ [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ] && \
+ [ ! -s $data/segments ]; then
+
+ echo "$0: $data/wav.scp indexed by utt-id; copying utt2dur to reco2dur"
+ cp $data/utt2dur $data/reco2dur && exit 0;
+
+elif [ -f $data/wav.scp ]; then
+ echo "$0: obtaining durations from recordings"
+
+ # if the wav.scp contains only lines of the form
+ # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
+ if cat $data/wav.scp | perl -e '
+ while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+ $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+ $reco = $A[0]; $sphere_file = $A[4];
+
+ if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+ $sample_rate = -1; $sample_count = -1;
+ for ($n = 0; $n <= 30; $n++) {
+ $line = <F>;
+ if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+ if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+ if ($line =~ m/end_head/) { break; }
+ }
+ close(F);
+ if ($sample_rate == -1 || $sample_count == -1) {
+ die "could not parse sphere header from $sphere_file";
+ }
+ $duration = $sample_count * 1.0 / $sample_rate;
+ print "$reco $duration\n";
+ } ' > $data/reco2dur; then
+ echo "$0: successfully obtained recording lengths from sphere-file headers"
+ else
+ echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
+ if ! command -v wav-to-duration >/dev/null; then
+ echo "$0: wav-to-duration is not on your path"
+ exit 1;
+ fi
+
+ read_entire_file=false
+ if grep -q 'sox.*speed' $data/wav.scp; then
+ read_entire_file=true
+ echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+ echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+ echo "... perturb_data_dir_speed_3way.sh."
+ fi
+
+ num_recos=$(wc -l <$data/wav.scp)
+ if [ $nj -gt $num_recos ]; then
+ nj=$num_recos
+ fi
+
+ temp_data_dir=$data/wav${nj}split
+ wavscps=$(for n in `seq $nj`; do echo $temp_data_dir/$n/wav.scp; done)
+ subdirs=$(for n in `seq $nj`; do echo $temp_data_dir/$n; done)
+
+ if ! mkdir -p $subdirs >&/dev/null; then
+ for n in `seq $nj`; do
+ mkdir -p $temp_data_dir/$n
+ done
+ fi
+
+ utils/split_scp.pl $data/wav.scp $wavscps
+
+
+ $cmd JOB=1:$nj $data/log/get_reco_durations.JOB.log \
+ wav-to-duration --read-entire-file=$read_entire_file \
+ scp:$temp_data_dir/JOB/wav.scp ark,t:$temp_data_dir/JOB/reco2dur || \
+ { echo "$0: there was a problem getting the durations"; exit 1; } # This could
+
+ for n in `seq $nj`; do
+ cat $temp_data_dir/$n/reco2dur
+ done > $data/reco2dur
+ fi
+ rm -r $temp_data_dir
+else
+ echo "$0: Expected $data/wav.scp to exist"
+ exit 1
+fi
+
+len1=$(wc -l < $data/wav.scp)
+len2=$(wc -l < $data/reco2dur)
+if [ "$len1" != "$len2" ]; then
+ echo "$0: warning: length of reco2dur does not equal that of wav.scp, $len2 != $len1"
+ if [ $len1 -gt $[$len2*2] ]; then
+ echo "$0: less than half of recordings got a duration: failing."
+ exit 1
+ fi
+fi
+
+echo "$0: computed $data/reco2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
new file mode 100755
index 0000000..9310715
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+# This script operates on a data directory, such as in data/train/,
+# and writes new segments to stdout. The file 'segments' maps from
+# utterance to time offsets into a recording, with the format:
+# <utterance-id> <recording-id> <segment-begin> <segment-end>
+# This script assumes utterance and recording ids are the same (i.e., that
+# wav.scp is indexed by utterance), and uses durations from 'utt2dur',
+# created if necessary by get_utt2dur.sh.
+
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train > data/train/segments"
+ exit 1
+fi
+
+data=$1
+
+if [ ! -s $data/utt2dur ]; then
+ local/data/get_utt2dur.sh $data 1>&2 || exit 1;
+fi
+
+# <utt-id> <utt-id> 0 <utt-dur>
+awk '{ print $1, $1, 0, $2 }' $data/utt2dur
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
new file mode 100755
index 0000000..833a7fc
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
@@ -0,0 +1,135 @@
+#!/usr/bin/env bash
+
+# Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+# This script operates on a data directory, such as in data/train/, and adds the
+# utt2dur file if it does not already exist. The file 'utt2dur' maps from
+# utterance to the duration of the utterance in seconds. This script works it
+# out from the 'segments' file, or, if not present, from the wav.scp file (it
+# first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+
+frame_shift=0.01
+cmd=run.pl
+nj=4
+read_entire_file=false
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: $0 [options] <datadir>"
+ echo "e.g.:"
+ echo " $0 data/train"
+ echo " Options:"
+ echo " --frame-shift # frame shift in seconds. Only relevant when we are"
+ echo " # getting duration from feats.scp, and only if the "
+ echo " # file frame_shift does not exist (default: 0.01). "
+ exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+if [ -s $data/utt2dur ] && \
+ [ $(wc -l < $data/utt2spk) -eq $(wc -l < $data/utt2dur) ]; then
+ echo "$0: $data/utt2dur already exists with the expected length. We won't recompute it."
+ exit 0;
+fi
+
+if [ -s $data/segments ]; then
+ echo "$0: working out $data/utt2dur from $data/segments"
+ awk '{len=$4-$3; print $1, len;}' < $data/segments > $data/utt2dur
+elif [[ -s $data/frame_shift && -f $data/utt2num_frames ]]; then
+ echo "$0: computing $data/utt2dur from $data/{frame_shift,utt2num_frames}."
+ frame_shift=$(cat $data/frame_shift) || exit 1
+ # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+ awk -v fs=$frame_shift '{ $2=($2+1.5)*fs; print }' <$data/utt2num_frames >$data/utt2dur
+elif [ -f $data/wav.scp ]; then
+ echo "$0: segments file does not exist so getting durations from wave files"
+
+ # if the wav.scp contains only lines of the form
+ # utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
+ if perl <$data/wav.scp -e '
+ while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
+ @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+ $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+ $utt = $A[0]; $sphere_file = $A[4];
+
+ if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+ $sample_rate = -1; $sample_count = -1;
+ for ($n = 0; $n <= 30; $n++) {
+ $line = <F>;
+ if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+ if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+ if ($line =~ m/end_head/) { break; }
+ }
+ close(F);
+ if ($sample_rate == -1 || $sample_count == -1) {
+ die "could not parse sphere header from $sphere_file";
+ }
+ $duration = $sample_count * 1.0 / $sample_rate;
+ print "$utt $duration\n";
+ } ' > $data/utt2dur; then
+ echo "$0: successfully obtained utterance lengths from sphere-file headers"
+ else
+ echo "$0: could not get utterance lengths from sphere-file headers, using wav-to-duration"
+ if ! command -v wav-to-duration >/dev/null; then
+ echo "$0: wav-to-duration is not on your path"
+ exit 1;
+ fi
+
+ if grep -q 'sox.*speed' $data/wav.scp; then
+ read_entire_file=true
+ echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+ echo "... It is much faster if you call get_utt2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+ echo "... perturb_data_dir_speed_3way.sh."
+ fi
+
+
+ num_utts=$(wc -l <$data/utt2spk)
+ if [ $nj -gt $num_utts ]; then
+ nj=$num_utts
+ fi
+
+ local/data/split_data.sh --per-utt $data $nj
+ sdata=$data/split${nj}utt
+
+ $cmd JOB=1:$nj $data/log/get_durations.JOB.log \
+ wav-to-duration --read-entire-file=$read_entire_file \
+ scp:$sdata/JOB/wav.scp ark,t:$sdata/JOB/utt2dur || \
+ { echo "$0: there was a problem getting the durations"; exit 1; }
+
+ for n in `seq $nj`; do
+ cat $sdata/$n/utt2dur
+ done > $data/utt2dur
+ fi
+elif [ -f $data/feats.scp ]; then
+ echo "$0: wave file does not exist so getting durations from feats files"
+ if [[ -s $data/frame_shift ]]; then
+ frame_shift=$(cat $data/frame_shift) || exit 1
+ echo "$0: using frame_shift=$frame_shift from file $data/frame_shift"
+ fi
+ # The 1.5 correction is the typical value of (frame_length-frame_shift)/frame_shift.
+ feat-to-len scp:$data/feats.scp ark,t:- |
+ awk -v frame_shift=$frame_shift '{print $1, ($2+1.5)*frame_shift}' >$data/utt2dur
+else
+ echo "$0: Expected $data/wav.scp, $data/segments or $data/feats.scp to exist"
+ exit 1
+fi
+
+len1=$(wc -l < $data/utt2spk)
+len2=$(wc -l < $data/utt2dur)
+if [ "$len1" != "$len2" ]; then
+ echo "$0: warning: length of utt2dur does not equal that of utt2spk, $len2 != $len1"
+ if [ $len1 -gt $[$len2*2] ]; then
+ echo "$0: less than half of utterances got a duration: failing."
+ exit 1
+ fi
+fi
+
+echo "$0: computed $data/utt2dur"
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/local/data/split_data.sh b/egs/alimeeting/sa-asr/local/data/split_data.sh
new file mode 100755
index 0000000..97ad8c5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/data/split_data.sh
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+# Copyright 2010-2013 Microsoft Corporation
+# Johns Hopkins University (Author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+split_per_spk=true
+if [ "$1" == "--per-utt" ]; then
+ split_per_spk=false
+ shift
+fi
+
+if [ $# != 2 ]; then
+ echo "Usage: $0 [--per-utt] <data-dir> <num-to-split>"
+ echo "E.g.: $0 data/train 50"
+ echo "It creates its output in e.g. data/train/split50/{1,2,3,...50}, or if the "
+ echo "--per-utt option was given, in e.g. data/train/split50utt/{1,2,3,...50}."
+ echo ""
+ echo "This script will not split the data-dir if it detects that the output is newer than the input."
+ echo "By default it splits per speaker (so each speaker is in only one split dir),"
+ echo "but with the --per-utt option it will ignore the speaker information while splitting."
+ exit 1
+fi
+
+data=$1
+numsplit=$2
+
+if ! [ "$numsplit" -gt 0 ]; then
+ echo "Invalid num-split argument $numsplit";
+ exit 1;
+fi
+
+if $split_per_spk; then
+ warning_opt=
+else
+ # suppress warnings from filter_scps.pl about 'some input lines were output
+ # to multiple files'.
+ warning_opt="--no-warn"
+fi
+
+n=0;
+feats=""
+wavs=""
+utt2spks=""
+texts=""
+
+nu=`cat $data/utt2spk | wc -l`
+nf=`cat $data/feats.scp 2>/dev/null | wc -l`
+nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
+if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
+ echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
+ echo "** use local/fix_data_dir.sh $data to fix this."
+fi
+if [ -f $data/text ] && [ $nu -ne $nt ]; then
+ echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
+ echo "** use local/fix_data_dir.sh to fix this."
+fi
+
+
+if $split_per_spk; then
+ utt2spk_opt="--utt2spk=$data/utt2spk"
+ utt=""
+else
+ utt2spk_opt=
+ utt="utt"
+fi
+
+s1=$data/split${numsplit}${utt}/1
+if [ ! -d $s1 ]; then
+ need_to_split=true
+else
+ need_to_split=false
+ for f in utt2spk spk2utt spk2warp feats.scp text wav.scp cmvn.scp spk2gender \
+ vad.scp segments reco2file_and_channel utt2lang; do
+ if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
+ need_to_split=true
+ fi
+ done
+fi
+
+if ! $need_to_split; then
+ exit 0;
+fi
+
+utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done)
+
+directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done)
+
+# if this mkdir fails due to argument-list being too long, iterate.
+if ! mkdir -p $directories >&/dev/null; then
+ for n in `seq $numsplit`; do
+ mkdir -p $data/split${numsplit}${utt}/$n
+ done
+fi
+
+# If lockfile is not installed, just don't lock it. It's not a big deal.
+which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
+trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM
+
+utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
+
+for n in `seq $numsplit`; do
+ dsn=$data/split${numsplit}${utt}/$n
+ local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
+done
+
+maybe_wav_scp=
+if [ ! -f $data/segments ]; then
+ maybe_wav_scp=wav.scp # If there is no segments file, then wav file is
+ # indexed per utt.
+fi
+
+# split some things that are indexed by utterance.
+for f in feats.scp text vad.scp utt2lang $maybe_wav_scp utt2dur utt2num_frames; do
+ if [ -f $data/$f ]; then
+ utils/filter_scps.pl JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/utt2spk $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+ fi
+done
+
+# split some things that are indexed by speaker
+for f in spk2gender spk2warp cmvn.scp; do
+ if [ -f $data/$f ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/spk2utt $data/$f $data/split${numsplit}${utt}/JOB/$f || exit 1;
+ fi
+done
+
+if [ -f $data/segments ]; then
+ utils/filter_scps.pl JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/utt2spk $data/segments $data/split${numsplit}${utt}/JOB/segments || exit 1
+ for n in `seq $numsplit`; do
+ dsn=$data/split${numsplit}${utt}/$n
+ awk '{print $2;}' $dsn/segments | sort | uniq > $dsn/tmp.reco # recording-ids.
+ done
+ if [ -f $data/reco2file_and_channel ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/tmp.reco $data/reco2file_and_channel \
+ $data/split${numsplit}${utt}/JOB/reco2file_and_channel || exit 1
+ fi
+ if [ -f $data/wav.scp ]; then
+ utils/filter_scps.pl $warning_opt JOB=1:$numsplit \
+ $data/split${numsplit}${utt}/JOB/tmp.reco $data/wav.scp \
+ $data/split${numsplit}${utt}/JOB/wav.scp || exit 1
+ fi
+ for f in $data/split${numsplit}${utt}/*/tmp.reco; do rm $f; done
+fi
+
+exit 0
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa-asr/local/download_xvector_model.py
new file mode 100644
index 0000000..7da6559
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/download_xvector_model.py
@@ -0,0 +1,6 @@
+from modelscope.hub.snapshot_download import snapshot_download
+import sys
+
+
+cache_dir = sys.argv[1]
+model_dir = snapshot_download('damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch', cache_dir=cache_dir)
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
new file mode 100644
index 0000000..e606162
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
@@ -0,0 +1,22 @@
+import sys
+if __name__=="__main__":
+ uttid_path=sys.argv[1]
+ src_path=sys.argv[2]
+ tgt_path=sys.argv[3]
+ uttid_file=open(uttid_path,'r')
+ uttid_line=uttid_file.readlines()
+ uttid_file.close()
+ ori_utt2spk_all_fifo_file=open(src_path+'/utt2spk_all_fifo','r')
+ ori_utt2spk_all_fifo_line=ori_utt2spk_all_fifo_file.readlines()
+ ori_utt2spk_all_fifo_file.close()
+ new_utt2spk_all_fifo_file=open(tgt_path+'/utt2spk_all_fifo','w')
+
+ uttid_list=[]
+ for line in uttid_line:
+ uttid_list.append(line.strip())
+
+ for line in ori_utt2spk_all_fifo_line:
+ if line.strip().split(' ')[0] in uttid_list:
+ new_utt2spk_all_fifo_file.write(line)
+
+ new_utt2spk_all_fifo_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/fix_data_dir.sh b/egs/alimeeting/sa-asr/local/fix_data_dir.sh
new file mode 100755
index 0000000..3abd465
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/fix_data_dir.sh
@@ -0,0 +1,215 @@
+#!/usr/bin/env bash
+
+# This script makes sure that only the segments present in
+# all of "feats.scp", "wav.scp" [if present], segments [if present]
+# text, and utt2spk are present in any of them.
+# It puts the original contents of data-dir into
+# data-dir/.backup
+
+cmd="$@"
+
+utt_extra_files=
+spk_extra_files=
+
+. utils/parse_options.sh
+
+if [ $# != 1 ]; then
+ echo "Usage: utils/data/fix_data_dir.sh <data-dir>"
+ echo "e.g.: utils/data/fix_data_dir.sh data/train"
+ echo "This script helps ensure that the various files in a data directory"
+ echo "are correctly sorted and filtered, for example removing utterances"
+ echo "that have no features (if feats.scp is present)"
+ exit 1
+fi
+
+data=$1
+
+if [ -f $data/images.scp ]; then
+ image/fix_data_dir.sh $cmd
+ exit $?
+fi
+
+mkdir -p $data/.backup
+
+[ ! -d $data ] && echo "$0: no such directory $data" && exit 1;
+
+[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1;
+
+set -e -o pipefail -u
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
+
+export LC_ALL=C
+
+function check_sorted {
+ file=$1
+ sort -k1,1 -u <$file >$file.tmp
+ if ! cmp -s $file $file.tmp; then
+ echo "$0: file $1 is not in sorted order or not unique, sorting it"
+ mv $file.tmp $file
+ else
+ rm $file.tmp
+ fi
+}
+
+for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \
+ reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do
+ if [ -f $data/$x ]; then
+ cp $data/$x $data/.backup/$x
+ check_sorted $data/$x
+ fi
+done
+
+
+function filter_file {
+ filter=$1
+ file_to_filter=$2
+ cp $file_to_filter ${file_to_filter}.tmp
+ utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter
+ if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then
+ length1=$(cat ${file_to_filter}.tmp | wc -l)
+ length2=$(cat ${file_to_filter} | wc -l)
+ if [ $length1 -ne $length2 ]; then
+ echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter."
+ fi
+ fi
+ rm $file_to_filter.tmp
+}
+
+function filter_recordings {
+ # We call this once before the stage when we filter on utterance-id, and once
+ # after.
+
+ if [ -f $data/segments ]; then
+ # We have a segments file -> we need to filter this and the file wav.scp, and
+ # reco2file_and_utt, if it exists, to make sure they have the same list of
+ # recording-ids.
+
+ if [ ! -f $data/wav.scp ]; then
+ echo "$0: $data/segments exists but not $data/wav.scp"
+ exit 1;
+ fi
+ awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings
+ n1=$(cat $tmpdir/recordings | wc -l)
+ [ ! -s $tmpdir/recordings ] && \
+ echo "Empty list of recordings (bad file $data/segments)?" && exit 1;
+ utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp
+ mv $tmpdir/recordings.tmp $tmpdir/recordings
+
+
+ cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+ filter_file $tmpdir/recordings $data/segments
+ cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments
+ rm $data/segments.tmp
+
+ filter_file $tmpdir/recordings $data/wav.scp
+ [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel
+ [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur
+ true
+ fi
+}
+
+function filter_speakers {
+ # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
+ local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+ cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+ for s in cmvn.scp spk2gender; do
+ f=$data/$s
+ if [ -f $f ]; then
+ filter_file $f $tmpdir/speakers
+ fi
+ done
+
+ filter_file $tmpdir/speakers $data/spk2utt
+ local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+
+ for s in cmvn.scp spk2gender $spk_extra_files; do
+ f=$data/$s
+ if [ -f $f ]; then
+ filter_file $tmpdir/speakers $f
+ fi
+ done
+}
+
+function filter_utts {
+ cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+
+ ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \
+ echo "utt2spk is not in sorted order (fix this yourself)" && exit 1;
+
+ ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \
+ echo "utt2spk is not in sorted order when sorted first on speaker-id " && \
+ echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+
+ ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \
+ echo "spk2utt is not in sorted order (fix this yourself)" && exit 1;
+
+ if [ -f $data/utt2uniq ]; then
+ ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \
+ echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1;
+ fi
+
+ maybe_wav=
+ maybe_reco2dur=
+ [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist.
+ [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts
+
+ maybe_utt2dur=
+ if [ -f $data/utt2dur ]; then
+ cat $data/utt2dur | \
+ awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2dur.ok || exit 1
+ maybe_utt2dur=utt2dur.ok
+ fi
+
+ maybe_utt2num_frames=
+ if [ -f $data/utt2num_frames ]; then
+ cat $data/utt2num_frames | \
+ awk '{ if (NF == 2 && $2 > 0) { print }}' > $data/utt2num_frames.ok || exit 1
+ maybe_utt2num_frames=utt2num_frames.ok
+ fi
+
+ for x in feats.scp text segments utt2lang $maybe_wav $maybe_utt2dur $maybe_utt2num_frames; do
+ if [ -f $data/$x ]; then
+ utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp
+ mv $tmpdir/utts.tmp $tmpdir/utts
+ fi
+ done
+ rm $data/utt2dur.ok 2>/dev/null || true
+ rm $data/utt2num_frames.ok 2>/dev/null || true
+
+ [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \
+ rm $tmpdir/utts && exit 1;
+
+
+ if [ -f $data/utt2spk ]; then
+ new_nutts=$(cat $tmpdir/utts | wc -l)
+ old_nutts=$(cat $data/utt2spk | wc -l)
+ if [ $new_nutts -ne $old_nutts ]; then
+ echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts"
+ else
+ echo "fix_data_dir.sh: kept all $old_nutts utterances."
+ fi
+ fi
+
+ for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do
+ if [ -f $data/$x ]; then
+ cp $data/$x $data/.backup/$x
+ if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then
+ utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x
+ fi
+ fi
+ done
+
+}
+
+filter_recordings
+filter_speakers
+filter_utts
+filter_speakers
+filter_recordings
+
+local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+
+echo "fix_data_dir.sh: old files are kept in $data/.backup"
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa-asr/local/format_wav_scp.py
new file mode 100755
index 0000000..1fd63d6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/format_wav_scp.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+from typing import Tuple, Optional
+
+import kaldiio
+import humanfriendly
+import numpy as np
+import resampy
+import soundfile
+from tqdm import tqdm
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpWriter
+
+
+def humanfriendly_or_none(value: str):
+ if value in ("none", "None", "NONE"):
+ return None
+ return humanfriendly.parse_size(value)
+
+
+def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
+ """
+
+ >>> str2int_tuple('3,4,5')
+ (3, 4, 5)
+
+ """
+ assert check_argument_types()
+ if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
+ return None
+ return tuple(map(int, integers.strip().split(",")))
+
+
+def main():
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=logfmt)
+ logging.info(get_commandline_args())
+
+ parser = argparse.ArgumentParser(
+ description='Create waves list from "wav.scp"',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("scp")
+ parser.add_argument("outdir")
+ parser.add_argument(
+ "--name",
+ default="wav",
+ help="Specify the prefix word of output file name " 'such as "wav.scp"',
+ )
+ parser.add_argument("--segments", default=None)
+ parser.add_argument(
+ "--fs",
+ type=humanfriendly_or_none,
+ default=None,
+ help="If the sampling rate specified, " "Change the sampling rate.",
+ )
+ parser.add_argument("--audio-format", default="wav")
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument("--ref-channels", default=None, type=str2int_tuple)
+ group.add_argument("--utt2ref-channels", default=None, type=str)
+ args = parser.parse_args()
+
+ out_num_samples = Path(args.outdir) / f"utt2num_samples"
+
+ if args.ref_channels is not None:
+
+ def utt2ref_channels(x) -> Tuple[int, ...]:
+ return args.ref_channels
+
+ elif args.utt2ref_channels is not None:
+ utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
+
+ def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
+ chs_str = d[x]
+ return tuple(map(int, chs_str.split()))
+
+ else:
+ utt2ref_channels = None
+
+ Path(args.outdir).mkdir(parents=True, exist_ok=True)
+ out_wavscp = Path(args.outdir) / f"{args.name}.scp"
+ if args.segments is not None:
+ # Note: kaldiio supports only wav-pcm-int16le file.
+ loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
+ if args.audio_format.endswith("ark"):
+ fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+ fscp = out_wavscp.open("w")
+ else:
+ writer = SoundScpWriter(
+ args.outdir,
+ out_wavscp,
+ format=args.audio_format,
+ )
+
+ with out_num_samples.open("w") as fnum_samples:
+ for uttid, (rate, wave) in tqdm(loader):
+ # wave: (Time,) or (Time, Nmic)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is incompatible with Kaldi
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fscp,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+
+ else:
+ writer[uttid] = rate, wave
+ fnum_samples.write(f"{uttid} {len(wave)}\n")
+ else:
+ if args.audio_format.endswith("ark"):
+ fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+ else:
+ wavdir = Path(args.outdir) / f"data_{args.name}"
+ wavdir.mkdir(parents=True, exist_ok=True)
+
+ with Path(args.scp).open("r") as fscp, out_wavscp.open(
+ "w"
+ ) as fout, out_num_samples.open("w") as fnum_samples:
+ for line in tqdm(fscp):
+ uttid, wavpath = line.strip().split(None, 1)
+
+ if wavpath.endswith("|"):
+ # Streaming input e.g. cat a.wav |
+ with kaldiio.open_like_kaldi(wavpath, "rb") as f:
+ with BytesIO(f.read()) as g:
+ wave, rate = soundfile.read(g, dtype=np.int16)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is incompatible with Kaldi
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fout,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+ else:
+ owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+ soundfile.write(owavpath, wave, rate)
+ fout.write(f"{uttid} {owavpath}\n")
+ else:
+ wave, rate = soundfile.read(wavpath, dtype=np.int16)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+ save_asis = False
+
+ elif args.audio_format.endswith("ark"):
+ save_asis = False
+
+ elif Path(wavpath).suffix == "." + args.audio_format and (
+ args.fs is None or args.fs == rate
+ ):
+ save_asis = True
+
+ else:
+ save_asis = False
+
+ if save_asis:
+ # Neither --segments nor --fs are specified and
+ # the line doesn't end with "|",
+ # i.e. not using unix-pipe,
+ # only in this case,
+ # just using the original file as is.
+ fout.write(f"{uttid} {wavpath}\n")
+ else:
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is not supported in Kaldi.
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fout,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+ else:
+ owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+ soundfile.write(owavpath, wave, rate)
+ fout.write(f"{uttid} {owavpath}\n")
+ fnum_samples.write(f"{uttid} {len(wave)}\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa-asr/local/format_wav_scp.sh
new file mode 100755
index 0000000..04fc4a5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/format_wav_scp.sh
@@ -0,0 +1,142 @@
+#!/usr/bin/env bash
+set -euo pipefail
+SECONDS=0
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+help_message=$(cat << EOF
+Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
+e.g.
+$0 data/test/wav.scp data/test_format/
+
+Format 'wav.scp': In short words,
+changing "kaldi-datadir" to "modified-kaldi-datadir"
+
+The 'wav.scp' format in kaldi is very flexible,
+e.g. It can use unix-pipe as describing that wav file,
+but it sometime looks confusing and make scripts more complex.
+This tools creates actual wav files from 'wav.scp'
+and also segments wav files using 'segments'.
+
+Options
+ --fs <fs>
+ --segments <segments>
+ --nj <nj>
+ --cmd <cmd>
+EOF
+)
+
+out_filename=wav.scp
+cmd=utils/run.pl
+nj=30
+fs=none
+segments=
+
+ref_channels=
+utt2ref_channels=
+
+audio_format=wav
+write_utt2num_samples=true
+
+log "$0 $*"
+. utils/parse_options.sh
+
+if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
+ log "${help_message}"
+ log "Error: invalid command line arguments"
+ exit 1
+fi
+
+. ./path.sh # Setup the environment
+
+scp=$1
+if [ ! -f "${scp}" ]; then
+ log "${help_message}"
+ echo "$0: Error: No such file: ${scp}"
+ exit 1
+fi
+dir=$2
+
+
+if [ $# -eq 2 ]; then
+ logdir=${dir}/logs
+ outdir=${dir}/data
+
+elif [ $# -eq 3 ]; then
+ logdir=$3
+ outdir=${dir}/data
+
+elif [ $# -eq 4 ]; then
+ logdir=$3
+ outdir=$4
+fi
+
+
+mkdir -p ${logdir}
+
+rm -f "${dir}/${out_filename}"
+
+
+opts=
+if [ -n "${utt2ref_channels}" ]; then
+ opts="--utt2ref-channels ${utt2ref_channels} "
+elif [ -n "${ref_channels}" ]; then
+ opts="--ref-channels ${ref_channels} "
+fi
+
+
+if [ -n "${segments}" ]; then
+ log "[info]: using ${segments}"
+ nutt=$(<${segments} wc -l)
+ nj=$((nj<nutt?nj:nutt))
+
+ split_segments=""
+ for n in $(seq ${nj}); do
+ split_segments="${split_segments} ${logdir}/segments.${n}"
+ done
+
+ utils/split_scp.pl "${segments}" ${split_segments}
+
+ ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+ local/format_wav_scp.py \
+ ${opts} \
+ --fs ${fs} \
+ --audio-format "${audio_format}" \
+ "--segment=${logdir}/segments.JOB" \
+ "${scp}" "${outdir}/format.JOB"
+
+else
+ log "[info]: without segments"
+ nutt=$(<${scp} wc -l)
+ nj=$((nj<nutt?nj:nutt))
+
+ split_scps=""
+ for n in $(seq ${nj}); do
+ split_scps="${split_scps} ${logdir}/wav.${n}.scp"
+ done
+
+ utils/split_scp.pl "${scp}" ${split_scps}
+ ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+ local/format_wav_scp.py \
+ ${opts} \
+ --fs "${fs}" \
+ --audio-format "${audio_format}" \
+ "${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
+fi
+
+# Workaround for the NFS problem
+ls ${outdir}/format.* > /dev/null
+
+# concatenate the .scp files together.
+for n in $(seq ${nj}); do
+ cat "${outdir}/format.${n}/wav.scp" || exit 1;
+done > "${dir}/${out_filename}" || exit 1
+
+if "${write_utt2num_samples}"; then
+ for n in $(seq ${nj}); do
+ cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
+ done > "${dir}/utt2num_samples" || exit 1
+fi
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
new file mode 100644
index 0000000..c37abf9
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
@@ -0,0 +1,167 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+from itertools import permutations
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn import cluster
+
+
+def custom_spectral_clustering(affinity, min_n_clusters=2, max_n_clusters=4, refine=True,
+ threshold=0.995, laplacian_type="graph_cut"):
+ if refine:
+ # Symmetrization
+ affinity = np.maximum(affinity, np.transpose(affinity))
+ # Diffusion
+ affinity = np.matmul(affinity, np.transpose(affinity))
+ # Row-wise max normalization
+ row_max = affinity.max(axis=1, keepdims=True)
+ affinity = affinity / row_max
+
+ # a) Construct S and set diagonal elements to 0
+ affinity = affinity - np.diag(np.diag(affinity))
+ # b) Compute Laplacian matrix L and perform normalization:
+ degree = np.diag(np.sum(affinity, axis=1))
+ laplacian = degree - affinity
+ if laplacian_type == "random_walk":
+ degree_norm = np.diag(1 / (np.diag(degree) + 1e-10))
+ laplacian_norm = degree_norm.dot(laplacian)
+ else:
+ degree_half = np.diag(degree) ** 0.5 + 1e-15
+ laplacian_norm = laplacian / degree_half[:, np.newaxis] / degree_half
+
+ # c) Compute eigenvalues and eigenvectors of L_norm
+ eigenvalues, eigenvectors = np.linalg.eig(laplacian_norm)
+ eigenvalues = eigenvalues.real
+ eigenvectors = eigenvectors.real
+ index_array = np.argsort(eigenvalues)
+ eigenvalues = eigenvalues[index_array]
+ eigenvectors = eigenvectors[:, index_array]
+
+ # d) Compute the number of clusters k
+ k = min_n_clusters
+ for k in range(min_n_clusters, max_n_clusters + 1):
+ if eigenvalues[k] > threshold:
+ break
+ k = max(k, min_n_clusters)
+ spectral_embeddings = eigenvectors[:, :k]
+ # print(mid, k, eigenvalues[:10])
+
+ spectral_embeddings = spectral_embeddings / np.linalg.norm(spectral_embeddings, axis=1, ord=2, keepdims=True)
+ solver = cluster.KMeans(n_clusters=k, max_iter=1000, random_state=42)
+ solver.fit(spectral_embeddings)
+ return solver.labels_
+
+
+if __name__ == "__main__":
+ path = sys.argv[1] # dump2/raw/Eval_Ali_far
+ raw_path = sys.argv[2] # data/local/Eval_Ali_far
+ threshold = float(sys.argv[3]) # 0.996
+ sv_threshold = float(sys.argv[4]) # 0.815
+ wav_scp_file = open(path+'/wav.scp', 'r')
+ wav_scp = wav_scp_file.readlines()
+ wav_scp_file.close()
+ raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp = raw_meeting_scp_file.readlines()
+ raw_meeting_scp_file.close()
+ segments_scp_file = open(raw_path + '/segments', 'r')
+ segments_scp = segments_scp_file.readlines()
+ segments_scp_file.close()
+
+ segments_map = {}
+ for line in segments_scp:
+ line_list = line.strip().split(' ')
+ meeting = line_list[1]
+ seg = (float(line_list[-2]), float(line_list[-1]))
+ if meeting not in segments_map.keys():
+ segments_map[meeting] = [seg]
+ else:
+ segments_map[meeting].append(seg)
+
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+ )
+
+ chunk_len = int(1.5*16000) # 1.5 seconds
+ hop_len = int(0.75*16000) # 0.75 seconds
+
+ os.system("mkdir -p " + path + "/cluster_profile_infer")
+ cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
+ meeting_map = {}
+ for line in raw_meeting_scp:
+ meeting = line.strip().split('\t')[0]
+ wav_path = line.strip().split('\t')[1]
+ wav = soundfile.read(wav_path)[0]
+ # take the first channel
+ if wav.ndim == 2:
+ wav=wav[:, 0]
+ # gen_seg_embedding
+ segments_list = segments_map[meeting]
+
+ # import ipdb;ipdb.set_trace()
+ all_seg_embedding_list = []
+ for seg in segments_list:
+ wav_seg = wav[int(seg[0] * 16000): int(seg[1] * 16000)]
+ wav_seg_len = wav_seg.shape[0]
+ i = 0
+ while i < wav_seg_len:
+ if i + chunk_len < wav_seg_len:
+ cur_wav_chunk = wav_seg[i: i+chunk_len]
+ else:
+ cur_wav_chunk=wav_seg[i: ]
+ # chunks under 0.2s are ignored
+ if cur_wav_chunk.shape[0] >= 0.2 * 16000:
+ cur_chunk_embedding = inference_sv_pipline(audio_in=cur_wav_chunk)["spk_embedding"]
+ all_seg_embedding_list.append(cur_chunk_embedding)
+ i += hop_len
+ all_seg_embedding = np.vstack(all_seg_embedding_list)
+ # all_seg_embedding (n, dim)
+
+ # compute affinity
+ affinity=cosine_similarity(all_seg_embedding)
+
+ affinity = np.maximum(affinity - sv_threshold, 0.0001) / (affinity.max() - sv_threshold)
+
+ # clustering
+ labels = custom_spectral_clustering(
+ affinity=affinity,
+ min_n_clusters=2,
+ max_n_clusters=4,
+ refine=True,
+ threshold=threshold,
+ laplacian_type="graph_cut")
+
+
+ cluster_dict={}
+ for j in range(labels.shape[0]):
+ if labels[j] not in cluster_dict.keys():
+ cluster_dict[labels[j]] = np.atleast_2d(all_seg_embedding[j])
+ else:
+ cluster_dict[labels[j]] = np.concatenate((cluster_dict[labels[j]], np.atleast_2d(all_seg_embedding[j])))
+
+ emb_list = []
+ # get cluster center
+ for k in cluster_dict.keys():
+ cluster_dict[k] = np.mean(cluster_dict[k], axis=0)
+ emb_list.append(cluster_dict[k])
+
+ spk_num = len(emb_list)
+ profile_for_infer = np.vstack(emb_list)
+ # save profile for each meeting
+ np.save(path + '/cluster_profile_infer/' + meeting + '.npy', profile_for_infer)
+ meeting_map[meeting] = (path + '/cluster_profile_infer/' + meeting + '.npy', spk_num)
+ cluster_spk_num_file.write(meeting + ' ' + str(spk_num) + '\n')
+ cluster_spk_num_file.flush()
+
+ cluster_spk_num_file.close()
+
+ profile_scp = open(path + "/cluster_profile_infer.scp", 'w')
+ for line in wav_scp:
+ uttid = line.strip().split(' ')[0]
+ meeting = uttid.split('-')[0]
+ profile_scp.write(uttid + ' ' + meeting_map[meeting][0] + '\n')
+ profile_scp.flush()
+ profile_scp.close()
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
new file mode 100644
index 0000000..18286b4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
@@ -0,0 +1,70 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+import sys
+import os
+import soundfile
+
+
+if __name__=="__main__":
+ path = sys.argv[1] # dump2/raw/Eval_Ali_far
+ raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
+ raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp = raw_meeting_scp_file.readlines()
+ raw_meeting_scp_file.close()
+ segments_scp_file = open(raw_path + '/segments', 'r')
+ segments_scp = segments_scp_file.readlines()
+ segments_scp_file.close()
+
+ oracle_emb_dir = path + '/oracle_embedding/'
+ os.system("mkdir -p " + oracle_emb_dir)
+ oracle_emb_scp_file = open(path+'/oracle_embedding.scp', 'w')
+
+ raw_wav_map = {}
+ for line in raw_meeting_scp:
+ meeting = line.strip().split('\t')[0]
+ wav_path = line.strip().split('\t')[1]
+ raw_wav_map[meeting] = wav_path
+
+ spk_map = {}
+ for line in segments_scp:
+ line_list = line.strip().split(' ')
+ meeting = line_list[1]
+ spk_id = line_list[0].split('_')[3]
+ spk = meeting + '_' + spk_id
+ time_start = float(line_list[-2])
+ time_end = float(line_list[-1])
+ if time_end - time_start > 0.5:
+ if spk not in spk_map.keys():
+ spk_map[spk] = [(int(time_start * 16000), int(time_end * 16000))]
+ else:
+ spk_map[spk].append((int(time_start * 16000), int(time_end * 16000)))
+
+ inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+ )
+
+ for spk in spk_map.keys():
+ meeting = spk.split('_SPK')[0]
+ wav_path = raw_wav_map[meeting]
+ wav = soundfile.read(wav_path)[0]
+ # take the first channel
+ if wav.ndim == 2:
+ wav = wav[:, 0]
+ all_seg_embedding_list=[]
+ # import ipdb;ipdb.set_trace()
+ for seg_time in spk_map[spk]:
+ if seg_time[0] < wav.shape[0] - 0.5 * 16000:
+ if seg_time[1] > wav.shape[0]:
+ cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: ])["spk_embedding"]
+ else:
+ cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg_time[0]: seg_time[1]])["spk_embedding"]
+ all_seg_embedding_list.append(cur_seg_embedding)
+ all_seg_embedding = np.vstack(all_seg_embedding_list)
+ spk_embedding = np.mean(all_seg_embedding, axis=0)
+ np.save(oracle_emb_dir + spk + '.npy', spk_embedding)
+ oracle_emb_scp_file.write(spk + ' ' + oracle_emb_dir + spk + '.npy' + '\n')
+ oracle_emb_scp_file.flush()
+
+ oracle_emb_scp_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
new file mode 100644
index 0000000..f44fcd4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
@@ -0,0 +1,59 @@
+import random
+import numpy as np
+import os
+import sys
+
+
+if __name__=="__main__":
+ path = sys.argv[1] # dump2/raw/Eval_Ali_far
+ wav_scp_file = open(path+"/wav.scp", 'r')
+ wav_scp = wav_scp_file.readlines()
+ wav_scp_file.close()
+ spk2id_file = open(path + "/spk2id", 'r')
+ spk2id = spk2id_file.readlines()
+ spk2id_file.close()
+ embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
+ embedding_scp = embedding_scp_file.readlines()
+ embedding_scp_file.close()
+
+ embedding_map = {}
+ for line in embedding_scp:
+ spk = line.strip().split(' ')[0]
+ if spk not in embedding_map.keys():
+ emb=np.load(line.strip().split(' ')[1])
+ embedding_map[spk] = emb
+
+ meeting_map_tmp = {}
+ global_spk_list = []
+ for line in spk2id:
+ line_list = line.strip().split(' ')
+ meeting = line_list[0].split('-')[0]
+ spk_id = line_list[0].split('-')[-1].split('_')[-1]
+ spk = meeting + '_' + spk_id
+ global_spk_list.append(spk)
+ if meeting in meeting_map_tmp.keys():
+ meeting_map_tmp[meeting].append(spk)
+ else:
+ meeting_map_tmp[meeting] = [spk]
+
+ meeting_map = {}
+ os.system('mkdir -p ' + path + '/oracle_profile_nopadding')
+ for meeting in meeting_map_tmp.keys():
+ emb_list = []
+ for i in range(len(meeting_map_tmp[meeting])):
+ spk = meeting_map_tmp[meeting][i]
+ emb_list.append(embedding_map[spk])
+ profile = np.vstack(emb_list)
+ np.save(path + '/oracle_profile_nopadding/' + meeting + '.npy', profile)
+ meeting_map[meeting] = path + '/oracle_profile_nopadding/' + meeting + '.npy'
+
+ profile_scp = open(path + '/oracle_profile_nopadding.scp', 'w')
+ profile_map_scp = open(path + '/oracle_profile_nopadding_spk_list', 'w')
+
+ for line in wav_scp:
+ uttid = line.strip().split(' ')[0]
+ meeting = uttid.split('-')[0]
+ profile_scp.write(uttid + ' ' + meeting_map[meeting] + '\n')
+ profile_map_scp.write(uttid + ' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
+ profile_scp.close()
+ profile_map_scp.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
new file mode 100644
index 0000000..b70a32a
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
@@ -0,0 +1,68 @@
+import random
+import numpy as np
+import os
+import sys
+
+
+if __name__=="__main__":
+ path = sys.argv[1] # dump2/raw/Train_Ali_far
+ wav_scp_file = open(path+"/wav.scp", 'r')
+ wav_scp = wav_scp_file.readlines()
+ wav_scp_file.close()
+ spk2id_file = open(path+"/spk2id", 'r')
+ spk2id = spk2id_file.readlines()
+ spk2id_file.close()
+ embedding_scp_file = open(path + "/oracle_embedding.scp", 'r')
+ embedding_scp = embedding_scp_file.readlines()
+ embedding_scp_file.close()
+
+ embedding_map = {}
+ for line in embedding_scp:
+ spk = line.strip().split(' ')[0]
+ if spk not in embedding_map.keys():
+ emb = np.load(line.strip().split(' ')[1])
+ embedding_map[spk] = emb
+
+ meeting_map_tmp = {}
+ global_spk_list = []
+ for line in spk2id:
+ line_list = line.strip().split(' ')
+ meeting = line_list[0].split('-')[0]
+ spk_id = line_list[0].split('-')[-1].split('_')[-1]
+ spk = meeting+'_' + spk_id
+ global_spk_list.append(spk)
+ if meeting in meeting_map_tmp.keys():
+ meeting_map_tmp[meeting].append(spk)
+ else:
+ meeting_map_tmp[meeting] = [spk]
+
+ for meeting in meeting_map_tmp.keys():
+ num = len(meeting_map_tmp[meeting])
+ if num < 4:
+ global_spk_list_tmp = global_spk_list[: ]
+ for spk in meeting_map_tmp[meeting]:
+ global_spk_list_tmp.remove(spk)
+ padding_spk = random.sample(global_spk_list_tmp, 4 - num)
+ meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk
+
+ meeting_map = {}
+ os.system('mkdir -p ' + path + '/oracle_profile_padding')
+ for meeting in meeting_map_tmp.keys():
+ emb_list = []
+ for i in range(len(meeting_map_tmp[meeting])):
+ spk = meeting_map_tmp[meeting][i]
+ emb_list.append(embedding_map[spk])
+ profile = np.vstack(emb_list)
+ np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile)
+ meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy'
+
+ profile_scp = open(path + '/oracle_profile_padding.scp', 'w')
+ profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w')
+
+ for line in wav_scp:
+ uttid = line.strip().split(' ')[0]
+ meeting = uttid.split('-')[0]
+ profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n')
+ profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n')
+ profile_scp.close()
+ profile_map_scp.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
new file mode 100755
index 0000000..1022ae6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+
+# 2020 @kamo-naoyuki
+# This file was copied from Kaldi and
+# I deleted parts related to wav duration
+# because we shouldn't use kaldi's command here
+# and we don't need the files actually.
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+# 2014 Tom Ko
+# 2018 Emotech LTD (author: Pawel Swietojanski)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+# wav.scp
+# spk2utt
+# utt2spk
+# text
+#
+# It generates the files which are used for perturbing the speed of the original data.
+
+export LC_ALL=C
+set -euo pipefail
+
+if [[ $# != 3 ]]; then
+ echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
+ echo "e.g.:"
+ echo " $0 0.9 data/train_si284 data/train_si284p"
+ exit 1
+fi
+
+factor=$1
+srcdir=$2
+destdir=$3
+label="sp"
+spk_prefix="${label}${factor}-"
+utt_prefix="${label}${factor}-"
+
+#check is sox on the path
+
+! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
+
+if [[ ! -f ${srcdir}/utt2spk ]]; then
+ echo "$0: no such file ${srcdir}/utt2spk"
+ exit 1;
+fi
+
+if [[ ${destdir} == "${srcdir}" ]]; then
+ echo "$0: this script requires <srcdir> and <destdir> to be different."
+ exit 1
+fi
+
+mkdir -p "${destdir}"
+
+<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
+<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
+<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
+if [[ ! -f ${srcdir}/utt2uniq ]]; then
+ <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
+else
+ <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
+fi
+
+
+<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
+ local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+
+local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+
+if [[ -f ${srcdir}/segments ]]; then
+
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+ local/apply_map.pl -f 2 "${destdir}"/reco_map | \
+ awk -v factor="${factor}" \
+ '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
+ >"${destdir}"/segments
+
+ local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ if [[ -f ${srcdir}/reco2file_and_channel ]]; then
+ local/apply_map.pl -f 1 "${destdir}"/reco_map \
+ <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
+ fi
+
+else # no segments->wav indexed by utterance.
+ if [[ -f ${srcdir}/wav.scp ]]; then
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ fi
+fi
+
+if [[ -f ${srcdir}/text ]]; then
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+fi
+if [[ -f ${srcdir}/spk2gender ]]; then
+ local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+fi
+if [[ -f ${srcdir}/utt2lang ]]; then
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+fi
+
+rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
+echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
+
+local/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
new file mode 100755
index 0000000..d900bb1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ args = parser.parse_args()
+ return args
+
+class Segment(object):
+ def __init__(self, uttid, text):
+ self.uttid = uttid
+ self.text = text
+
+def main(args):
+ text = codecs.open(Path(args.path) / "text", "r", "utf-8")
+ spk2utt = codecs.open(Path(args.path) / "spk2utt", "r", "utf-8")
+ utt2spk = codecs.open(Path(args.path) / "utt2spk_all_fifo", "r", "utf-8")
+ spk2id = codecs.open(Path(args.path) / "spk2id", "w", "utf-8")
+
+ spkid_map = {}
+ meetingid_map = {}
+ for line in spk2utt:
+ spkid = line.strip().split(" ")[0]
+ meeting_id_list = spkid.split("_")[:3]
+ meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+ if meeting_id not in meetingid_map:
+ meetingid_map[meeting_id] = 1
+ else:
+ meetingid_map[meeting_id] += 1
+ spkid_map[spkid] = meetingid_map[meeting_id]
+ spk2id.write("%s %s\n" % (spkid, meetingid_map[meeting_id]))
+
+ utt2spklist = {}
+ for line in utt2spk:
+ uttid = line.strip().split(" ")[0]
+ spkid = line.strip().split(" ")[1]
+ spklist = spkid.split("$")
+ tmp = []
+ for index in range(len(spklist)):
+ tmp.append(spkid_map[spklist[index]])
+ utt2spklist[uttid] = tmp
+ # parse the textgrid file for each utterance
+ all_segments = []
+ for line in text:
+ uttid = line.strip().split(" ")[0]
+ context = line.strip().split(" ")[1]
+ spklist = utt2spklist[uttid]
+ length_text = len(context)
+ cnt = 0
+ tmp_text = ""
+ for index in range(length_text):
+ if context[index] != "$":
+ tmp_text += str(spklist[cnt])
+ else:
+ tmp_text += "$"
+ cnt += 1
+ tmp_seg = Segment(uttid,tmp_text)
+ all_segments.append(tmp_seg)
+
+ text.close()
+ utt2spk.close()
+ spk2utt.close()
+ spk2id.close()
+
+ text_id = codecs.open(Path(args.path) / "text_id", "w", "utf-8")
+
+ for i in range(len(all_segments)):
+ uttid_tmp = all_segments[i].uttid
+ text_tmp = all_segments[i].text
+
+ text_id.write("%s %s\n" % (uttid_tmp, text_tmp))
+
+ text_id.close()
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa-asr/local/process_text_id.py
new file mode 100644
index 0000000..0a9506e
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_id.py
@@ -0,0 +1,24 @@
+import sys
+if __name__=="__main__":
+ path=sys.argv[1]
+
+ text_id_old_file=open(path+"/text_id",'r')
+ text_id_old=text_id_old_file.readlines()
+ text_id_old_file.close()
+
+ text_id=open(path+"/text_id_train",'w')
+ for line in text_id_old:
+ uttid=line.strip().split(' ')[0]
+ old_id=line.strip().split(' ')[1]
+ pre_id='0'
+ new_id_list=[]
+ for i in old_id:
+ if i == '$':
+ new_id_list.append(pre_id)
+ else:
+ new_id_list.append(str(int(i)-1))
+ pre_id=str(int(i)-1)
+ new_id_list.append(pre_id)
+ new_id=' '.join(new_id_list)
+ text_id.write(uttid+' '+new_id+'\n')
+ text_id.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
new file mode 100644
index 0000000..f15d509
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_text_spk_merge.py
@@ -0,0 +1,55 @@
+import sys
+
+
+if __name__ == "__main__":
+ path=sys.argv[1]
+ text_scp_file = open(path + '/text', 'r')
+ text_scp = text_scp_file.readlines()
+ text_scp_file.close()
+ text_id_scp_file = open(path + '/text_id', 'r')
+ text_id_scp = text_id_scp_file.readlines()
+ text_id_scp_file.close()
+ text_spk_merge_file = open(path + '/text_spk_merge', 'w')
+ assert len(text_scp) == len(text_id_scp)
+
+ meeting_map = {} # {meeting_id: [(start_time, text, text_id), (start_time, text, text_id), ...]}
+ for i in range(len(text_scp)):
+ text_line = text_scp[i].strip().split(' ')
+ text_id_line = text_id_scp[i].strip().split(' ')
+ assert text_line[0] == text_id_line[0]
+ if len(text_line) > 1:
+ uttid = text_line[0]
+ text = text_line[1]
+ text_id = text_id_line[1]
+ meeting_id = uttid.split('-')[0]
+ start_time = int(uttid.split('-')[-2])
+ if meeting_id not in meeting_map:
+ meeting_map[meeting_id] = [(start_time,text,text_id)]
+ else:
+ meeting_map[meeting_id].append((start_time,text,text_id))
+
+ for meeting_id in sorted(meeting_map.keys()):
+ cur_meeting_list = sorted(meeting_map[meeting_id], key=lambda x: x[0])
+ text_spk_merge_map = {} #{1: text1, 2: text2, ...}
+ for cur_utt in cur_meeting_list:
+ cur_text = cur_utt[1]
+ cur_text_id = cur_utt[2]
+ assert len(cur_text)==len(cur_text_id)
+ if len(cur_text) != 0:
+ cur_text_split = cur_text.split('$')
+ cur_text_id_split = cur_text_id.split('$')
+ assert len(cur_text_split) == len(cur_text_id_split)
+ for i in range(len(cur_text_split)):
+ if len(cur_text_split[i]) != 0:
+ spk_id = int(cur_text_id_split[i][0])
+ if spk_id not in text_spk_merge_map.keys():
+ text_spk_merge_map[spk_id] = cur_text_split[i]
+ else:
+ text_spk_merge_map[spk_id] += cur_text_split[i]
+ text_spk_merge_list = []
+ for spk_id in sorted(text_spk_merge_map.keys()):
+ text_spk_merge_list.append(text_spk_merge_map[spk_id])
+ text_spk_merge_file.write(meeting_id + ' ' + '$'.join(text_spk_merge_list) + '\n')
+ text_spk_merge_file.flush()
+
+ text_spk_merge_file.close()
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
new file mode 100755
index 0000000..fdf2460
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+Process the textgrid files
+"""
+import argparse
+import codecs
+from distutils.util import strtobool
+from pathlib import Path
+import textgrid
+import pdb
+import numpy as np
+import sys
+import math
+
+
+class Segment(object):
+ def __init__(self, uttid, spkr, stime, etime, text):
+ self.uttid = uttid
+ self.spkr = spkr
+ self.stime = round(stime, 2)
+ self.etime = round(etime, 2)
+ self.text = text
+
+ def change_stime(self, time):
+ self.stime = time
+
+ def change_etime(self, time):
+ self.etime = time
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="process the textgrid files")
+ parser.add_argument("--path", type=str, required=True, help="Data path")
+ args = parser.parse_args()
+ return args
+
+
+
+def main(args):
+ textgrid_flist = codecs.open(Path(args.path) / "textgrid.flist", "r", "utf-8")
+ segment_file = codecs.open(Path(args.path)/"segments", "w", "utf-8")
+ utt2spk = codecs.open(Path(args.path)/"utt2spk", "w", "utf-8")
+
+ # get the path of textgrid file for each utterance
+ for line in textgrid_flist:
+ line_array = line.strip().split(" ")
+ path = Path(line_array[1])
+ uttid = line_array[0]
+
+ try:
+ tg = textgrid.TextGrid.fromFile(path)
+ except:
+ pdb.set_trace()
+ num_spk = tg.__len__()
+ spk2textgrid = {}
+ spk2weight = {}
+ weight2spk = {}
+ cnt = 2
+ xmax = 0
+ for i in range(tg.__len__()):
+ spk_name = tg[i].name
+ if spk_name not in spk2weight:
+ spk2weight[spk_name] = cnt
+ weight2spk[cnt] = spk_name
+ cnt = cnt * 2
+ segments = []
+ for j in range(tg[i].__len__()):
+ if tg[i][j].mark:
+ if xmax < tg[i][j].maxTime:
+ xmax = tg[i][j].maxTime
+ segments.append(
+ Segment(
+ uttid,
+ tg[i].name,
+ tg[i][j].minTime,
+ tg[i][j].maxTime,
+ tg[i][j].mark.strip(),
+ )
+ )
+ segments = sorted(segments, key=lambda x: x.stime)
+ spk2textgrid[spk_name] = segments
+ olp_label = np.zeros((num_spk, int(xmax/0.01)), dtype=np.int32)
+ for spkid in spk2weight.keys():
+ weight = spk2weight[spkid]
+ segments = spk2textgrid[spkid]
+ idx = int(math.log2(weight) )- 1
+ for i in range(len(segments)):
+ stime = segments[i].stime
+ etime = segments[i].etime
+ olp_label[idx, int(stime/0.01): int(etime/0.01)] = weight
+ sum_label = olp_label.sum(axis=0)
+ stime = 0
+ pre_value = 0
+ for pos in range(sum_label.shape[0]):
+ if sum_label[pos] in weight2spk:
+ if pre_value in weight2spk:
+ if sum_label[pos] != pre_value:
+ spkids = weight2spk[pre_value]
+ spkid_array = spkids.split("_")
+ spkid = spkid_array[-1]
+ #spkid = uttid+spkid
+ if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
+ segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
+ utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
+ stime = pos
+ pre_value = sum_label[pos]
+ else:
+ stime = pos
+ pre_value = sum_label[pos]
+ else:
+ if pre_value in weight2spk:
+ spkids = weight2spk[pre_value]
+ spkid_array = spkids.split("_")
+ spkid = spkid_array[-1]
+ #spkid = uttid+spkid
+ if round(stime*0.01, 2) != round((pos-1)*0.01, 2):
+ segment_file.write("%s_%s_%s_%s %s %s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid, round(stime*0.01, 2) ,round((pos-1)*0.01, 2)))
+ utt2spk.write("%s_%s_%s_%s %s\n" % (uttid, spkid, str(int(stime)).zfill(7), str(int(pos-1)).zfill(7), uttid+"_"+spkid))
+ stime = pos
+ pre_value = sum_label[pos]
+ textgrid_flist.close()
+ segment_file.close()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
new file mode 100755
index 0000000..23992f2
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
@@ -0,0 +1,27 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+while(<>){
+ @A = split(" ", $_);
+ @A > 1 || die "Invalid line in spk2utt file: $_";
+ $s = shift @A;
+ foreach $u ( @A ) {
+ print "$u $s\n";
+ }
+}
+
+
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa-asr/local/text_format.pl
new file mode 100755
index 0000000..45f1f64
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_format.pl
@@ -0,0 +1,14 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright Chao Weng
+
+# normalizations for hkust trascript
+# see the docs/trans-guidelines.pdf for details
+
+while (<STDIN>) {
+ @A = split(" ", $_);
+ if (@A == 1) {
+ next;
+ }
+ print $_
+}
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa-asr/local/text_normalize.pl
new file mode 100755
index 0000000..ac301d4
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/text_normalize.pl
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# Copyright Chao Weng
+
+# normalizations for hkust trascript
+# see the docs/trans-guidelines.pdf for details
+
+while (<STDIN>) {
+ @A = split(" ", $_);
+ print "$A[0] ";
+ for ($n = 1; $n < @A; $n++) {
+ $tmp = $A[$n];
+ if ($tmp =~ /<sil>/) {$tmp =~ s:<sil>::g;}
+ if ($tmp =~ /<%>/) {$tmp =~ s:<%>::g;}
+ if ($tmp =~ /<->/) {$tmp =~ s:<->::g;}
+ if ($tmp =~ /<\$>/) {$tmp =~ s:<\$>::g;}
+ if ($tmp =~ /<#>/) {$tmp =~ s:<#>::g;}
+ if ($tmp =~ /<_>/) {$tmp =~ s:<_>::g;}
+ if ($tmp =~ /<space>/) {$tmp =~ s:<space>::g;}
+ if ($tmp =~ /`/) {$tmp =~ s:`::g;}
+ if ($tmp =~ /&/) {$tmp =~ s:&::g;}
+ if ($tmp =~ /,/) {$tmp =~ s:,::g;}
+ if ($tmp =~ /[a-zA-Z]/) {$tmp=uc($tmp);}
+ if ($tmp =~ /锛�/) {$tmp =~ s:锛�:A:g;}
+ if ($tmp =~ /锝�/) {$tmp =~ s:锝�:A:g;}
+ if ($tmp =~ /锝�/) {$tmp =~ s:锝�:B:g;}
+ if ($tmp =~ /锝�/) {$tmp =~ s:锝�:C:g;}
+ if ($tmp =~ /锝�/) {$tmp =~ s:锝�:K:g;}
+ if ($tmp =~ /锝�/) {$tmp =~ s:锝�:T:g;}
+ if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
+ if ($tmp =~ /涓�/) {$tmp =~ s:涓�::g;}
+ if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
+ if ($tmp =~ /銆�/) {$tmp =~ s:銆�::g;}
+ if ($tmp =~ /锛�/) {$tmp =~ s:锛�::g;}
+ print "$tmp ";
+ }
+ print "\n";
+}
diff --git a/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
new file mode 100755
index 0000000..6e0e438
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
@@ -0,0 +1,38 @@
+#!/usr/bin/env perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# converts an utt2spk file to a spk2utt file.
+# Takes input from the stdin or from a file argument;
+# output goes to the standard out.
+
+if ( @ARGV > 1 ) {
+ die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt";
+}
+
+while(<>){
+ @A = split(" ", $_);
+ @A == 2 || die "Invalid line in utt2spk file: $_";
+ ($u,$s) = @A;
+ if(!$seen_spk{$s}) {
+ $seen_spk{$s} = 1;
+ push @spklist, $s;
+ }
+ push (@{$spk_hash{$s}}, "$u");
+}
+foreach $s (@spklist) {
+ $l = join(' ',@{$spk_hash{$s}});
+ print "$s $l\n";
+}
diff --git a/egs/alimeeting/sa-asr/local/validate_data_dir.sh b/egs/alimeeting/sa-asr/local/validate_data_dir.sh
new file mode 100755
index 0000000..37c99ae
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/validate_data_dir.sh
@@ -0,0 +1,404 @@
+#!/usr/bin/env bash
+
+cmd="$@"
+
+no_feats=false
+no_wav=false
+no_text=false
+no_spk_sort=false
+non_print=false
+
+
+function show_help
+{
+ echo "Usage: $0 [--no-feats] [--no-text] [--non-print] [--no-wav] [--no-spk-sort] <data-dir>"
+ echo "The --no-xxx options mean that the script does not require "
+ echo "xxx.scp to be present, but it will check it if it is present."
+ echo "--no-spk-sort means that the script does not require the utt2spk to be "
+ echo "sorted by the speaker-id in addition to being sorted by utterance-id."
+ echo "--non-print ignore the presence of non-printable characters."
+ echo "By default, utt2spk is expected to be sorted by both, which can be "
+ echo "achieved by making the speaker-id prefixes of the utterance-ids"
+ echo "e.g.: $0 data/train"
+}
+
+while [ $# -ne 0 ] ; do
+ case "$1" in
+ "--no-feats")
+ no_feats=true;
+ ;;
+ "--no-text")
+ no_text=true;
+ ;;
+ "--non-print")
+ non_print=true;
+ ;;
+ "--no-wav")
+ no_wav=true;
+ ;;
+ "--no-spk-sort")
+ no_spk_sort=true;
+ ;;
+ *)
+ if ! [ -z "$data" ] ; then
+ show_help;
+ exit 1
+ fi
+ data=$1
+ ;;
+ esac
+ shift
+done
+
+
+
+if [ ! -d $data ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+if [ -f $data/images.scp ]; then
+ cmd=${cmd/--no-wav/} # remove --no-wav if supplied
+ image/validate_data_dir.sh $cmd
+ exit $?
+fi
+
+for f in spk2utt utt2spk; do
+ if [ ! -f $data/$f ]; then
+ echo "$0: no such file $f"
+ exit 1;
+ fi
+ if [ ! -s $data/$f ]; then
+ echo "$0: empty file $f"
+ exit 1;
+ fi
+done
+
+! cat $data/utt2spk | awk '{if (NF != 2) exit(1); }' && \
+ echo "$0: $data/utt2spk has wrong format." && exit;
+
+ns=$(wc -l < $data/spk2utt)
+if [ "$ns" == 1 ]; then
+ echo "$0: WARNING: you have only one speaker. This probably a bad idea."
+ echo " Search for the word 'bold' in http://kaldi-asr.org/doc/data_prep.html"
+ echo " for more information."
+fi
+
+
+tmpdir=$(mktemp -d /tmp/kaldi.XXXX);
+trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM
+
+export LC_ALL=C
+
+function check_sorted_and_uniq {
+ ! perl -ne '((substr $_,-1) eq "\n") or die "file $ARGV has invalid newline";' $1 && exit 1;
+ ! awk '{print $1}' < $1 | sort -uC && echo "$0: file $1 is not sorted or has duplicates" && exit 1;
+}
+
+function partial_diff {
+ diff -U1 $1 $2 | (head -n 6; echo "..."; tail -n 6)
+ n1=`cat $1 | wc -l`
+ n2=`cat $2 | wc -l`
+ echo "[Lengths are $1=$n1 versus $2=$n2]"
+}
+
+check_sorted_and_uniq $data/utt2spk
+
+if ! $no_spk_sort; then
+ ! sort -k2 -C $data/utt2spk && \
+ echo "$0: utt2spk is not in sorted order when sorted first on speaker-id " && \
+ echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1;
+fi
+
+check_sorted_and_uniq $data/spk2utt
+
+! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
+ <(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \
+ echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
+
+cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
+
+if [ ! -f $data/text ] && ! $no_text; then
+ echo "$0: no such file $data/text (if this is by design, specify --no-text)"
+ exit 1;
+fi
+
+num_utts=`cat $tmpdir/utts | wc -l`
+if ! $no_text; then
+ if ! $non_print; then
+ if locale -a | grep "C.UTF-8" >/dev/null; then
+ L=C.UTF-8
+ else
+ L=en_US.UTF-8
+ fi
+ n_non_print=$(LC_ALL="$L" grep -c '[^[:print:][:space:]]' $data/text) && \
+ echo "$0: text contains $n_non_print lines with non-printable characters" &&\
+ exit 1;
+ fi
+ local/validate_text.pl $data/text || exit 1;
+ check_sorted_and_uniq $data/text
+ text_len=`cat $data/text | wc -l`
+ illegal_sym_list="<s> </s> #0"
+ for x in $illegal_sym_list; do
+ if grep -w "$x" $data/text > /dev/null; then
+ echo "$0: Error: in $data, text contains illegal symbol $x"
+ exit 1;
+ fi
+ done
+ awk '{print $1}' < $data/text > $tmpdir/utts.txt
+ if ! cmp -s $tmpdir/utts{,.txt}; then
+ echo "$0: Error: in $data, utterance lists extracted from utt2spk and text"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.txt}
+ exit 1;
+ fi
+fi
+
+if [ -f $data/segments ] && [ ! -f $data/wav.scp ]; then
+ echo "$0: in directory $data, segments file exists but no wav.scp"
+ exit 1;
+fi
+
+
+if [ ! -f $data/wav.scp ] && ! $no_wav; then
+ echo "$0: no such file $data/wav.scp (if this is by design, specify --no-wav)"
+ exit 1;
+fi
+
+if [ -f $data/wav.scp ]; then
+ check_sorted_and_uniq $data/wav.scp
+
+ if grep -E -q '^\S+\s+~' $data/wav.scp; then
+ # note: it's not a good idea to have any kind of tilde in wav.scp, even if
+ # part of a command, as it would cause compatibility problems if run by
+ # other users, but this used to be not checked for so we let it slide unless
+ # it's something of the form "foo ~/foo.wav" (i.e. a plain file name) which
+ # would definitely cause problems as the fopen system call does not do
+ # tilde expansion.
+ echo "$0: Please do not use tilde (~) in your wav.scp."
+ exit 1;
+ fi
+
+ if [ -f $data/segments ]; then
+
+ check_sorted_and_uniq $data/segments
+ # We have a segments file -> interpret wav file as "recording-ids" not utterance-ids.
+ ! cat $data/segments | \
+ awk '{if (NF != 4 || $4 <= $3) { print "Bad line in segments file", $0; exit(1); }}' && \
+ echo "$0: badly formatted segments file" && exit 1;
+
+ segments_len=`cat $data/segments | wc -l`
+ if [ -f $data/text ]; then
+ ! cmp -s $tmpdir/utts <(awk '{print $1}' <$data/segments) && \
+ echo "$0: Utterance list differs between $data/utt2spk and $data/segments " && \
+ echo "$0: Lengths are $segments_len vs $num_utts" && \
+ exit 1
+ fi
+
+ cat $data/segments | awk '{print $2}' | sort | uniq > $tmpdir/recordings
+ awk '{print $1}' $data/wav.scp > $tmpdir/recordings.wav
+ if ! cmp -s $tmpdir/recordings{,.wav}; then
+ echo "$0: Error: in $data, recording-ids extracted from segments and wav.scp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/recordings{,.wav}
+ exit 1;
+ fi
+ if [ -f $data/reco2file_and_channel ]; then
+ # this file is needed only for ctm scoring; it's indexed by recording-id.
+ check_sorted_and_uniq $data/reco2file_and_channel
+ ! cat $data/reco2file_and_channel | \
+ awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+ if ( NF == 3 && $3 == "1" ) {
+ warning_issued = 1;
+ } else {
+ print "Bad line ", $0; exit 1;
+ }
+ }
+ }
+ END {
+ if (warning_issued == 1) {
+ print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+ }
+ }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+ cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/recordings.r2fc
+ if ! cmp -s $tmpdir/recordings{,.r2fc}; then
+ echo "$0: Error: in $data, recording-ids extracted from segments and reco2file_and_channel"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/recordings{,.r2fc}
+ exit 1;
+ fi
+ fi
+ else
+ # No segments file -> assume wav.scp indexed by utterance.
+ cat $data/wav.scp | awk '{print $1}' > $tmpdir/utts.wav
+ if ! cmp -s $tmpdir/utts{,.wav}; then
+ echo "$0: Error: in $data, utterance lists extracted from utt2spk and wav.scp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.wav}
+ exit 1;
+ fi
+
+ if [ -f $data/reco2file_and_channel ]; then
+ # this file is needed only for ctm scoring; it's indexed by recording-id.
+ check_sorted_and_uniq $data/reco2file_and_channel
+ ! cat $data/reco2file_and_channel | \
+ awk '{if (NF != 3 || ($3 != "A" && $3 != "B" )) {
+ if ( NF == 3 && $3 == "1" ) {
+ warning_issued = 1;
+ } else {
+ print "Bad line ", $0; exit 1;
+ }
+ }
+ }
+ END {
+ if (warning_issued == 1) {
+ print "The channel should be marked as A or B, not 1! You should change it ASAP! "
+ }
+ }' && echo "$0: badly formatted reco2file_and_channel file" && exit 1;
+ cat $data/reco2file_and_channel | awk '{print $1}' > $tmpdir/utts.r2fc
+ if ! cmp -s $tmpdir/utts{,.r2fc}; then
+ echo "$0: Error: in $data, utterance-ids extracted from segments and reco2file_and_channel"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.r2fc}
+ exit 1;
+ fi
+ fi
+ fi
+fi
+
+if [ ! -f $data/feats.scp ] && ! $no_feats; then
+ echo "$0: no such file $data/feats.scp (if this is by design, specify --no-feats)"
+ exit 1;
+fi
+
+if [ -f $data/feats.scp ]; then
+ check_sorted_and_uniq $data/feats.scp
+ cat $data/feats.scp | awk '{print $1}' > $tmpdir/utts.feats
+ if ! cmp -s $tmpdir/utts{,.feats}; then
+ echo "$0: Error: in $data, utterance-ids extracted from utt2spk and features"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.feats}
+ exit 1;
+ fi
+fi
+
+
+if [ -f $data/cmvn.scp ]; then
+ check_sorted_and_uniq $data/cmvn.scp
+ cat $data/cmvn.scp | awk '{print $1}' > $tmpdir/speakers.cmvn
+ cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+ if ! cmp -s $tmpdir/speakers{,.cmvn}; then
+ echo "$0: Error: in $data, speaker lists extracted from spk2utt and cmvn"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/speakers{,.cmvn}
+ exit 1;
+ fi
+fi
+
+if [ -f $data/spk2gender ]; then
+ check_sorted_and_uniq $data/spk2gender
+ ! cat $data/spk2gender | awk '{if (!((NF == 2 && ($2 == "m" || $2 == "f")))) exit 1; }' && \
+ echo "$0: Mal-formed spk2gender file" && exit 1;
+ cat $data/spk2gender | awk '{print $1}' > $tmpdir/speakers.spk2gender
+ cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+ if ! cmp -s $tmpdir/speakers{,.spk2gender}; then
+ echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2gender"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/speakers{,.spk2gender}
+ exit 1;
+ fi
+fi
+
+if [ -f $data/spk2warp ]; then
+ check_sorted_and_uniq $data/spk2warp
+ ! cat $data/spk2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+ echo "$0: Mal-formed spk2warp file" && exit 1;
+ cat $data/spk2warp | awk '{print $1}' > $tmpdir/speakers.spk2warp
+ cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
+ if ! cmp -s $tmpdir/speakers{,.spk2warp}; then
+ echo "$0: Error: in $data, speaker lists extracted from spk2utt and spk2warp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/speakers{,.spk2warp}
+ exit 1;
+ fi
+fi
+
+if [ -f $data/utt2warp ]; then
+ check_sorted_and_uniq $data/utt2warp
+ ! cat $data/utt2warp | awk '{if (!((NF == 2 && ($2 > 0.5 && $2 < 1.5)))){ print; exit 1; }}' && \
+ echo "$0: Mal-formed utt2warp file" && exit 1;
+ cat $data/utt2warp | awk '{print $1}' > $tmpdir/utts.utt2warp
+ cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts
+ if ! cmp -s $tmpdir/utts{,.utt2warp}; then
+ echo "$0: Error: in $data, utterance lists extracted from utt2spk and utt2warp"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.utt2warp}
+ exit 1;
+ fi
+fi
+
+# check some optionally-required things
+for f in vad.scp utt2lang utt2uniq; do
+ if [ -f $data/$f ]; then
+ check_sorted_and_uniq $data/$f
+ if ! cmp -s <( awk '{print $1}' $data/utt2spk ) \
+ <( awk '{print $1}' $data/$f ); then
+ echo "$0: error: in $data, $f and utt2spk do not have identical utterance-id list"
+ exit 1;
+ fi
+ fi
+done
+
+
+if [ -f $data/utt2dur ]; then
+ check_sorted_and_uniq $data/utt2dur
+ cat $data/utt2dur | awk '{print $1}' > $tmpdir/utts.utt2dur
+ if ! cmp -s $tmpdir/utts{,.utt2dur}; then
+ echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2dur file"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.utt2dur}
+ exit 1;
+ fi
+ cat $data/utt2dur | \
+ awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line utt2dur:" NR ":" $0; exit(1) }}' || exit 1
+fi
+
+if [ -f $data/utt2num_frames ]; then
+ check_sorted_and_uniq $data/utt2num_frames
+ cat $data/utt2num_frames | awk '{print $1}' > $tmpdir/utts.utt2num_frames
+ if ! cmp -s $tmpdir/utts{,.utt2num_frames}; then
+ echo "$0: Error: in $data, utterance-ids extracted from utt2spk and utt2num_frames file"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/utts{,.utt2num_frames}
+ exit 1
+ fi
+ awk <$data/utt2num_frames '{
+ if (NF != 2 || !($2 > 0) || $2 != int($2)) {
+ print "Bad line utt2num_frames:" NR ":" $0
+ exit 1 } }' || exit 1
+fi
+
+if [ -f $data/reco2dur ]; then
+ check_sorted_and_uniq $data/reco2dur
+ cat $data/reco2dur | awk '{print $1}' > $tmpdir/recordings.reco2dur
+ if [ -f $tmpdir/recordings ]; then
+ if ! cmp -s $tmpdir/recordings{,.reco2dur}; then
+ echo "$0: Error: in $data, recording-ids extracted from segments and reco2dur file"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/recordings{,.reco2dur}
+ exit 1;
+ fi
+ else
+ if ! cmp -s $tmpdir/{utts,recordings.reco2dur}; then
+ echo "$0: Error: in $data, recording-ids extracted from wav.scp and reco2dur file"
+ echo "$0: differ, partial diff is:"
+ partial_diff $tmpdir/{utts,recordings.reco2dur}
+ exit 1;
+ fi
+ fi
+ cat $data/reco2dur | \
+ awk '{ if (NF != 2 || !($2 > 0)) { print "Bad line : " $0; exit(1) }}' || exit 1
+fi
+
+
+echo "$0: Successfully validated data-directory $data"
diff --git a/egs/alimeeting/sa-asr/local/validate_text.pl b/egs/alimeeting/sa-asr/local/validate_text.pl
new file mode 100755
index 0000000..7f75cf1
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/validate_text.pl
@@ -0,0 +1,136 @@
+#!/usr/bin/env perl
+#
+#===============================================================================
+# Copyright 2017 Johns Hopkins University (author: Yenda Trmal <jtrmal@gmail.com>)
+# Johns Hopkins University (author: Daniel Povey)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+
+# validation script for data/<dataset>/text
+# to be called (preferably) from utils/validate_data_dir.sh
+use strict;
+use warnings;
+use utf8;
+use Fcntl qw< SEEK_SET >;
+
+# this function reads the opened file (supplied as a first
+# parameter) into an array of lines. For each
+# line, it tests whether it's a valid utf-8 compatible
+# line. If all lines are valid utf-8, it returns the lines
+# decoded as utf-8, otherwise it assumes the file's encoding
+# is one of those 1-byte encodings, such as ISO-8859-x
+# or Windows CP-X.
+# Please recall we do not really care about
+# the actually encoding, we just need to
+# make sure the length of the (decoded) string
+# is correct (to make the output formatting looking right).
+sub get_utf8_or_bytestream {
+ use Encode qw(decode encode);
+ my $is_utf_compatible = 1;
+ my @unicode_lines;
+ my @raw_lines;
+ my $raw_text;
+ my $lineno = 0;
+ my $file = shift;
+
+ while (<$file>) {
+ $raw_text = $_;
+ last unless $raw_text;
+ if ($is_utf_compatible) {
+ my $decoded_text = eval { decode("UTF-8", $raw_text, Encode::FB_CROAK) } ;
+ $is_utf_compatible = $is_utf_compatible && defined($decoded_text);
+ push @unicode_lines, $decoded_text;
+ } else {
+ #print STDERR "WARNING: the line $raw_text cannot be interpreted as UTF-8: $decoded_text\n";
+ ;
+ }
+ push @raw_lines, $raw_text;
+ $lineno += 1;
+ }
+
+ if (!$is_utf_compatible) {
+ return (0, @raw_lines);
+ } else {
+ return (1, @unicode_lines);
+ }
+}
+
+# check if the given unicode string contain unicode whitespaces
+# other than the usual four: TAB, LF, CR and SPACE
+sub validate_utf8_whitespaces {
+ my $unicode_lines = shift;
+ use feature 'unicode_strings';
+ for (my $i = 0; $i < scalar @{$unicode_lines}; $i++) {
+ my $current_line = $unicode_lines->[$i];
+ if ((substr $current_line, -1) ne "\n"){
+ print STDERR "$0: The current line (nr. $i) has invalid newline\n";
+ return 1;
+ }
+ my @A = split(" ", $current_line);
+ my $utt_id = $A[0];
+ # we replace TAB, LF, CR, and SPACE
+ # this is to simplify the test
+ if ($current_line =~ /\x{000d}/) {
+ print STDERR "$0: The line for utterance $utt_id contains CR (0x0D) character\n";
+ return 1;
+ }
+ $current_line =~ s/[\x{0009}\x{000a}\x{0020}]/./g;
+ if ($current_line =~/\s/) {
+ print STDERR "$0: The line for utterance $utt_id contains disallowed Unicode whitespaces\n";
+ return 1;
+ }
+ }
+ return 0;
+}
+
+# checks if the text in the file (supplied as the argument) is utf-8 compatible
+# if yes, checks if it contains only allowed whitespaces. If no, then does not
+# do anything. The function seeks to the original position in the file after
+# reading the text.
+sub check_allowed_whitespace {
+ my $file = shift;
+ my $filename = shift;
+ my $pos = tell($file);
+ (my $is_utf, my @lines) = get_utf8_or_bytestream($file);
+ seek($file, $pos, SEEK_SET);
+ if ($is_utf) {
+ my $has_invalid_whitespaces = validate_utf8_whitespaces(\@lines);
+ if ($has_invalid_whitespaces) {
+ print STDERR "$0: ERROR: text file '$filename' contains disallowed UTF-8 whitespace character(s)\n";
+ return 0;
+ }
+ }
+ return 1;
+}
+
+if(@ARGV != 1) {
+ die "Usage: validate_text.pl <text-file>\n" .
+ "e.g.: validate_text.pl data/train/text\n";
+}
+
+my $text = shift @ARGV;
+
+if (-z "$text") {
+ print STDERR "$0: ERROR: file '$text' is empty or does not exist\n";
+ exit 1;
+}
+
+if(!open(FILE, "<$text")) {
+ print STDERR "$0: ERROR: failed to open $text\n";
+ exit 1;
+}
+
+check_allowed_whitespace(\*FILE, $text) or exit 1;
+close(FILE);
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
new file mode 100755
index 0000000..5721f3f
--- /dev/null
+++ b/egs/alimeeting/sa-asr/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/run.sh b/egs/alimeeting/sa-asr/run.sh
new file mode 100755
index 0000000..e5297b8
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+stage=1
+stop_stage=18
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local.sh \
+ --device ${device} \
+ --ngpu ${ngpu} \
+ --stage ${stage} \
+ --stop_stage ${stop_stage} \
+ --gpu_inference true \
+ --njob_infer 4 \
+ --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+ --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+ --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+ --lang zh \
+ --audio_format wav \
+ --feats_type raw \
+ --token_type char \
+ --use_lm ${use_lm} \
+ --use_word_lm ${use_wordlm} \
+ --lm_config "${lm_config}" \
+ --asr_config "${asr_config}" \
+ --sa_asr_config "${sa_asr_config}" \
+ --inference_config "${inference_config}" \
+ --train_set "${train_set}" \
+ --valid_set "${valid_set}" \
+ --test_sets "${test_sets}" \
+ --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
new file mode 100755
index 0000000..1967864
--- /dev/null
+++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
@@ -0,0 +1,50 @@
+#!/usr/bin/env bash
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+ngpu=4
+device="0,1,2,3"
+
+stage=1
+stop_stage=4
+
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_2023_Ali_far"
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+inference_config=conf/decode_asr_rnn.yaml
+
+lm_config=conf/train_lm_transformer.yaml
+use_lm=false
+use_wordlm=false
+./asr_local_m2met_2023_infer.sh \
+ --device ${device} \
+ --ngpu ${ngpu} \
+ --stage ${stage} \
+ --stop_stage ${stop_stage} \
+ --gpu_inference true \
+ --njob_infer 4 \
+ --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
+ --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
+ --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
+ --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
+ --lang zh \
+ --audio_format wav \
+ --feats_type raw \
+ --token_type char \
+ --use_lm ${use_lm} \
+ --use_word_lm ${use_wordlm} \
+ --lm_config "${lm_config}" \
+ --asr_config "${asr_config}" \
+ --sa_asr_config "${sa_asr_config}" \
+ --inference_config "${inference_config}" \
+ --train_set "${train_set}" \
+ --valid_set "${valid_set}" \
+ --test_sets "${test_sets}" \
+ --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa-asr/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs_modelscope/asr/TEMPLATE/README.md b/egs_modelscope/asr/TEMPLATE/README.md
index 83c462d..7ff04eb 100644
--- a/egs_modelscope/asr/TEMPLATE/README.md
+++ b/egs_modelscope/asr/TEMPLATE/README.md
@@ -1,7 +1,7 @@
# Speech Recognition
> **Note**:
-> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take the typic models as examples to demonstrate the usage.
+> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take the typic models as examples to demonstrate the usage.
## Inference
@@ -19,22 +19,24 @@
rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)
```
-#### [Paraformer-online Model](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary)
+#### [Paraformer-online Model](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary)
```python
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model_revision='v1.0.4'
)
import soundfile
speech, sample_rate = soundfile.read("example/asr_example.wav")
-param_dict = {"cache": dict(), "is_final": False}
-chunk_stride = 7680# 480ms
-# first chunk, 480ms
+chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
+param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
+chunk_stride = chunk_size[1] * 960 # 600ms銆�480ms
+# first chunk, 600ms
speech_chunk = speech[0:chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)
-# next chunk, 480ms
+# next chunk, 600ms
speech_chunk = speech[chunk_stride:chunk_stride+chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)
@@ -42,7 +44,7 @@
Full code of demo, please ref to [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/241)
#### [UniASR Model](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary)
-There are three decoding mode for UniASR model(`fast`銆乣normal`銆乣offline`), for more model detailes, please refer to [docs](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary)
+There are three decoding mode for UniASR model(`fast`銆乣normal`銆乣offline`), for more model details, please refer to [docs](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary)
```python
decoding_model = "fast" # "fast"銆�"normal"銆�"offline"
inference_pipeline = pipeline(
@@ -59,7 +61,7 @@
Undo
#### [MFCCA Model](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)
-For more model detailes, please refer to [docs](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)
+For more model details, please refer to [docs](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)
```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -74,15 +76,15 @@
print(rec_result)
```
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.auto_speech_recognition`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
-##### Infer pipeline
+#### Infer pipeline
- `audio_in`: the input to decode, which could be:
- wav_path, `e.g.`: asr_example.wav,
- pcm_path, `e.g.`: asr_example.pcm,
@@ -100,20 +102,20 @@
### Inference with multi-thread CPUs or multi GPUs
FunASR also offer recipes [egs_modelscope/asr/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs.
-- Setting parameters in `infer.sh`
- - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- - `data_dir`: the dataset dir needs to include `wav.scp`. If `${data_dir}/text` is also exists, CER will be computed
- - `output_dir`: output dir of the recognition results
- - `batch_size`: `64` (Default), batch size of inference on gpu
- - `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
- - `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
- - `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
- - `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
- - `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
- - `decoding_mode`: `normal` (Default), decoding mode for UniASR model(fast銆乶ormal銆乷ffline)
- - `hotword_txt`: `None` (Default), hotword file for contextual paraformer model(the hotword file name ends with .txt")
+#### Settings of `infer.sh`
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `data_dir`: the dataset dir needs to include `wav.scp`. If `${data_dir}/text` is also exists, CER will be computed
+- `output_dir`: output dir of the recognition results
+- `batch_size`: `64` (Default), batch size of inference on gpu
+- `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
+- `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
+- `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
+- `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
+- `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
+- `decoding_mode`: `normal` (Default), decoding mode for UniASR model(fast銆乶ormal銆乷ffline)
+- `hotword_txt`: `None` (Default), hotword file for contextual paraformer model(the hotword file name ends with .txt")
-- Decode with multi GPUs:
+#### Decode with multi GPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
@@ -123,7 +125,7 @@
--gpu_inference true \
--gpuid_list "0,1"
```
-- Decode with multi-thread CPUs:
+#### Decode with multi-thread CPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
@@ -133,7 +135,7 @@
--njob 64
```
-- Results
+#### Results
The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
deleted file mode 100644
index c68a8cd..0000000
--- a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained Paraformer-large Model
-
-### Finetune
-
-- Modify finetune training related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
- - <strong>max_epoch:</strong> # number of training epoch
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
similarity index 82%
rename from egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
rename to egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
index 3594815..87bb652 100644
--- a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh
new file mode 120000
index 0000000..5e59f18
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh
@@ -0,0 +1 @@
+../../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
similarity index 82%
rename from egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
rename to egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
index b55b59f..3b0164a 100644
--- a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh
new file mode 120000
index 0000000..5e59f18
--- /dev/null
+++ b/egs_modelscope/asr/conformer/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh
@@ -0,0 +1 @@
+../../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
index 77b2cbd..7a6b750 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -16,13 +16,13 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
output_dir=output_dir_job,
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in)
+ inference_pipeline(audio_in=audio_in)
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
index 0d06377..f07f308 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
@@ -16,13 +16,13 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
output_dir=output_dir_job,
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in)
+ inference_pipeline(audio_in=audio_in)
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/demo.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/demo.py
new file mode 100644
index 0000000..f6026d6
--- /dev/null
+++ b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/demo.py
@@ -0,0 +1,11 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
+ model_revision='v3.0.0'
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
\ No newline at end of file
diff --git a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py b/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
deleted file mode 100755
index 333b66a..0000000
--- a/egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import json
-import os
-import shutil
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_after_finetune(params):
- # prepare for decoding
- pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
- for file_name in params["required_files"]:
- if file_name == "configuration.json":
- with open(os.path.join(pretrained_model_path, file_name)) as f:
- config_dict = json.load(f)
- config_dict["model"]["am_model_name"] = params["decoding_model_name"]
- with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
- json.dump(config_dict, f, indent=4, separators=(',', ': '))
- else:
- shutil.copy(os.path.join(pretrained_model_path, file_name),
- os.path.join(params["output_dir"], file_name))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
-
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=params["output_dir"],
- output_dir=decoding_path,
- batch_size=1
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in)
-
- # computer CER if GT text is set
- text_in = os.path.join(params["data_dir"], "text")
- if text_in is not None:
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
- text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
- with open(text_proc_file, 'r') as hyp_reader:
- with open(text_proc_file2, 'w') as hyp_writer:
- for line in hyp_reader:
- new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
- hyp_writer.write(new_context+'\n')
- text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
- with open(text_in, 'r') as ref_reader:
- with open(text_in2, 'w') as ref_writer:
- for line in ref_reader:
- new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
- ref_writer.write(new_context+'\n')
-
-
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
- compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
- params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./example_data/validation"
- params["decoding_model_name"] = "valid.acc.ave.pb"
- modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
deleted file mode 100644
index 49c0aeb..0000000
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# ModelScope Model
-
-## How to infer using a pretrained Paraformer-large Model
-
-### Inference
-
-You can use the pretrain model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # Support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
- - <strong>batch_size:</strong> # Set batch size in inference.
- - <strong>param_dict:</strong> # Set the hotword list in inference.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
new file mode 100644
index 0000000..9d08923
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/finetune.py
@@ -0,0 +1,37 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
+def modelscope_finetune(params):
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
+ # dataset split ["train", "validation"]
+ ds_dict = MsDataset.load(params.data_path)
+ kwargs = dict(
+ model=params.model,
+ model_revision="v1.0.2",
+ data_dir=ds_dict,
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
+ trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ params = modelscope_args(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", data_path="./data")
+ params.output_dir = "./checkpoint" # 妯″瀷淇濆瓨璺緞
+ params.data_path = "./example_data/" # 鏁版嵁璺緞
+ params.dataset_type = "large" # finetune contextual paraformer妯″瀷鍙兘浣跨敤large dataset
+ params.batch_bins = 200000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
+ params.max_epoch = 20 # 鏈�澶ц缁冭疆鏁�
+ params.lr = 0.0002 # 璁剧疆瀛︿範鐜�
+
+ modelscope_finetune(params)
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh
index e60f6d9..6325626 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.sh
@@ -12,7 +12,7 @@
batch_size=64
gpu_inference=true # whether to perform gpu decoding
gpuid_list="0,1" # set gpus, e.g., gpuid_list="0,1"
-njob=64 # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
+njob=10 # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"
hotword_txt=None
@@ -55,8 +55,8 @@
--audio_in ${output_dir}/split/wav.$JOB.scp \
--output_dir ${output_dir}/output.$JOB \
--batch_size ${batch_size} \
- --gpuid ${gpuid} \
- --hotword_txt ${hotword_txt}
+ --hotword_txt ${hotword_txt} \
+ --gpuid ${gpuid}
}&
done
wait
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
index 18897b1..97e9fce 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
@@ -19,11 +19,15 @@
os.makedirs(work_dir)
wav_file_path = os.path.join(work_dir, "wav.scp")
+ counter = 0
with codecs.open(wav_file_path, 'w') as fin:
for line in ds_dict:
+ counter += 1
wav = line["Audio:FILE"]
idx = wav.split("/")[-1].split(".")[0]
fin.writelines(idx + " " + wav + "\n")
+ if counter == 50:
+ break
audio_in = wav_file_path
inference_pipeline = pipeline(
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
new file mode 100644
index 0000000..b566454
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -0,0 +1,39 @@
+import os
+import logging
+import torch
+import soundfile
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+os.environ["MODELSCOPE_CACHE"] = "./"
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model_revision='v1.0.4'
+)
+
+model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
+speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
+speech_length = speech.shape[0]
+
+sample_offset = 0
+chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
+stride_size = chunk_size[1] * 960
+param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
+final_result = ""
+
+for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+ if sample_offset + stride_size >= speech_length - 1:
+ stride_size = speech_length - sample_offset
+ param_dict["is_final"] = True
+ rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
+ param_dict=param_dict)
+ if len(rec_result) != 0:
+ final_result += rec_result['text'] + " "
+ print(rec_result)
+print(final_result)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
deleted file mode 100644
index c740f71..0000000
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
+++ /dev/null
@@ -1,76 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained Paraformer-large Model
-
-### Finetune
-
-- Modify finetune training related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
- - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
- - <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
- - <strong>max_epoch:</strong> # number of training epoch
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.sh`
- - <strong>model:</strong> # model name on ModelScope
- - <strong>data_dir:</strong> # the dataset dir needs to include `${data_dir}/wav.scp`. If `${data_dir}/text` is also exists, CER will be computed
- - <strong>output_dir:</strong> # result dir
- - <strong>batch_size:</strong> # batchsize of inference
- - <strong>gpu_inference:</strong> # whether to perform gpu decoding, set false for cpu decoding
- - <strong>gpuid_list:</strong> # set gpus, e.g., gpuid_list="0,1"
- - <strong>njob:</strong> # the number of jobs for CPU decoding, if `gpu_inference`=false, use CPU decoding, please set `njob`
-
-- Decode with multi GPUs:
-```shell
- bash infer.sh \
- --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
- --data_dir "./data/test" \
- --output_dir "./results" \
- --batch_size 64 \
- --gpu_inference true \
- --gpuid_list "0,1"
-```
-
-- Decode with multi-thread CPUs:
-```shell
- bash infer.sh \
- --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
- --data_dir "./data/test" \
- --output_dir "./results" \
- --gpu_inference false \
- --njob 64
-```
-
-- Results
-
-The decoding results can be found in `${output_dir}/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
-
-If you decode the SpeechIO test sets, you can use textnorm with `stage`=3, and `DETAILS.txt`, `RESULTS.txt` record the results and CER after text normalization.
-
-### Inference using local finetuned model
-
-- Modify inference related parameters in `infer_after_finetune.py`
- - <strong>modelscope_model_name: </strong> # model name on ModelScope
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
- - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
- - <strong>batch_size:</strong> # batchsize of inference
-
-- Then you can run the pipeline to finetune with:
-```python
- python infer_after_finetune.py
-```
-
-- Results
-
-The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
new file mode 120000
index 0000000..92088a2
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
@@ -0,0 +1 @@
+../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
new file mode 120000
index 0000000..0b3b38b
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
@@ -0,0 +1 @@
+../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
deleted file mode 100644
index 2d311dd..0000000
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import json
-import os
-import shutil
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.hub.snapshot_download import snapshot_download
-
-from funasr.utils.compute_wer import compute_wer
-
-def modelscope_infer_after_finetune(params):
- # prepare for decoding
-
- try:
- pretrained_model_path = snapshot_download(params["modelscope_model_name"], cache_dir=params["output_dir"])
- except BaseException:
- raise BaseException(f"Please download pretrain model from ModelScope firstly.")
- shutil.copy(os.path.join(params["output_dir"], params["decoding_model_name"]), os.path.join(pretrained_model_path, "model.pb"))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
-
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=pretrained_model_path,
- output_dir=decoding_path,
- batch_size=params["batch_size"]
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in)
-
- # computer CER if GT text is set
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/text")
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
-
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data/test"
- params["decoding_model_name"] = "valid.acc.ave_10best.pb"
- params["batch_size"] = 64
- modelscope_infer_after_finetune(params)
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/infer.py
index d1fbca2..00be793 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/infer.py
@@ -16,14 +16,14 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch",
output_dir=output_dir_job,
batch_size=64
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in)
+ inference_pipeline(audio_in=audio_in)
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
deleted file mode 100644
index c68a8cd..0000000
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained Paraformer-large Model
-
-### Finetune
-
-- Modify finetune training related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
- - <strong>max_epoch:</strong> # number of training epoch
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
new file mode 120000
index 0000000..92088a2
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/README.md
@@ -0,0 +1 @@
+../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
similarity index 79%
rename from egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
rename to egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
index 8a6c87b..2863c1a 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/demo.py
@@ -4,12 +4,12 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch",
output_dir=output_dir,
- batch_size=32,
+ batch_size=1,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
new file mode 120000
index 0000000..f05fbbb
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
@@ -0,0 +1 @@
+../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh
new file mode 120000
index 0000000..0b3b38b
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.sh
@@ -0,0 +1 @@
+../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md
new file mode 120000
index 0000000..92088a2
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/README.md
@@ -0,0 +1 @@
+../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
similarity index 82%
rename from egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
rename to egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
index dec7de0..f2db74e 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/demo.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
new file mode 120000
index 0000000..f05fbbb
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
@@ -0,0 +1 @@
+../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh
new file mode 120000
index 0000000..0b3b38b
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.sh
@@ -0,0 +1 @@
+../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
index 2eb9cc8..6672bbf 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -14,24 +14,26 @@
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
- model_revision='v1.0.2')
+ model_revision='v1.0.4'
+)
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
sample_offset = 0
-step = 4800 #300ms
-param_dict = {"cache": dict(), "is_final": False}
+chunk_size = [8, 8, 4] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
+stride_size = chunk_size[1] * 960
+param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
final_result = ""
-for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
- if sample_offset + step >= speech_length - 1:
- step = speech_length - sample_offset
+for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+ if sample_offset + stride_size >= speech_length - 1:
+ stride_size = speech_length - sample_offset
param_dict["is_final"] = True
- rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
+ rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
param_dict=param_dict)
- if len(rec_result) != 0 and rec_result['text'] != "sil" and rec_result['text'] != "waiting_for_more_voice":
- final_result += rec_result['text']
- print(rec_result)
-print(final_result)
+ if len(rec_result) != 0:
+ final_result += rec_result['text'] + " "
+ print(rec_result)
+print(final_result.strip())
diff --git a/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py b/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
index df18903..f4c4fc2 100644
--- a/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
+++ b/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py b/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
index 83d6805..63bed40 100644
--- a/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
+++ b/egs_modelscope/asr/paraformerbert/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
index c151149..862f881 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cantonese-CHS.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py
index ac73adf..d4f8d76 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cantonese-CHS.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py
index 227f4bf..347d316 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/infer.py
index 74d9764..936d6d7 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
index 5ace7e4..f82c1f4 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_de.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py
index f8d91b8..48b4807 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_de.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
index 49b884b..98f31b6 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py
index 57a3afd..423c503 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
index 510f008..75e22a0 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_es.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py
index 2ec5940..cb1b4fa 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_es.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
index 040265d..e6c39c2 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
@@ -16,14 +16,14 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline",
output_dir=output_dir_job,
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
index 055e4eb..124d5ed 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
@@ -16,14 +16,14 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online",
output_dir=output_dir_job,
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
index 6aedeea..627d132 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_fr.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py
index 2f3e833..305d990 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_fr.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/infer.py
index c54ab8c..e0d1a4d 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_he.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
index 219c9ec..e53c37e 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_id.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py
index ad2671a..75ec783 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_id.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
index 1a174bb..68cc41d 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ja.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py
index f15bc2d..a741e18 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ja.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
index 618b3f6..b87bcbb 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ko.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py
index 135e8f8..9be791c 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ko.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/infer.py
index cfd869f..b3a9058 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_my.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
index 2dcb663..4a43e7c 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_pt.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py
index aff2a9a..7029fd9 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_pt.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
index 95f447d..3c9d364 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ru.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py
index 88c06b4..95da479 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ru.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/infer.py
index e8c5524..04b02fe 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ur.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
index 9472104..4218f3d 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_vi.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py
index 4a844fc..355e412 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py
@@ -4,10 +4,10 @@
if __name__ == "__main__":
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_vi.wav"
output_dir = "./results"
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
+ rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py
index 40686ac..3520989 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/infer.py
index dfe934d..a3e2a00 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
index ce8988e..13d2a2e 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
@@ -16,14 +16,14 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline",
output_dir=output_dir_job,
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in)
+ inference_pipeline(audio_in=audio_in)
def modelscope_infer(params):
# prepare for multi-GPU decoding
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
index 8b4a04d..876d51c 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
@@ -16,14 +16,14 @@
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online",
output_dir=output_dir_job,
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
+ inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
def modelscope_infer(params):
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py
index 1c1e303..8ec4288 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/infer.py
index 94c1b68..3ab16ea 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/infer.py
@@ -4,11 +4,11 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online",
output_dir=output_dir,
)
- rec_result = inference_pipline(audio_in=audio_in)
+ rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
index 94144ef..83c462d 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/README.md
@@ -1,46 +1,246 @@
-# ModelScope Model
+# Speech Recognition
-## How to finetune and infer using a pretrained Paraformer-large Model
+> **Note**:
+> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take the typic models as examples to demonstrate the usage.
-### Finetune
+## Inference
-- Modify finetune training related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
- - <strong>max_epoch:</strong> # number of training epoch
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
+### Quick start
+#### [Paraformer Model](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
```python
- python finetune.py
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
+```
+#### [Paraformer-online Model](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/summary)
+```python
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+ )
+import soundfile
+speech, sample_rate = soundfile.read("example/asr_example.wav")
+
+param_dict = {"cache": dict(), "is_final": False}
+chunk_stride = 7680# 480ms
+# first chunk, 480ms
+speech_chunk = speech[0:chunk_stride]
+rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
+print(rec_result)
+# next chunk, 480ms
+speech_chunk = speech[chunk_stride:chunk_stride+chunk_stride]
+rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
+print(rec_result)
+```
+Full code of demo, please ref to [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/241)
+
+#### [UniASR Model](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary)
+There are three decoding mode for UniASR model(`fast`銆乣normal`銆乣offline`), for more model detailes, please refer to [docs](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary)
+```python
+decoding_model = "fast" # "fast"銆�"normal"銆�"offline"
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825',
+ param_dict={"decoding_model": decoding_model})
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
+```
+The decoding mode of `fast` and `normal` is fake streaming, which could be used for evaluating of recognition accuracy.
+Full code of demo, please ref to [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/151)
+#### [RNN-T-online model]()
+Undo
+
+#### [MFCCA Model](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)
+For more model detailes, please refer to [docs](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
+ model_revision='v3.0.0'
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
```
-### Inference
+#### API-reference
+##### Define pipeline
+- `task`: `Tasks.auto_speech_recognition`
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
+- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
+- `output_dir`: `None` (Default), the output path of results if set
+- `batch_size`: `1` (Default), batch size when decoding
+##### Infer pipeline
+- `audio_in`: the input to decode, which could be:
+ - wav_path, `e.g.`: asr_example.wav,
+ - pcm_path, `e.g.`: asr_example.pcm,
+ - audio bytes stream, `e.g.`: bytes data from a microphone
+ - audio sample point锛宍e.g.`: `audio, rate = soundfile.read("asr_example_zh.wav")`, the dtype is numpy.ndarray or torch.Tensor
+ - wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
+ ```text
+ asr_example1 ./audios/asr_example1.wav
+ asr_example2 ./audios/asr_example2.wav
+ ```
+ In this case of `wav.scp` input, `output_dir` must be set to save the output results
+- `audio_fs`: audio sampling rate, only set when audio_in is pcm audio
+- `output_dir`: None (Default), the output path of results if set
-Or you can use the finetuned model for inference directly.
+### Inference with multi-thread CPUs or multi GPUs
+FunASR also offer recipes [egs_modelscope/asr/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs.
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
+- Setting parameters in `infer.sh`
+ - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+ - `data_dir`: the dataset dir needs to include `wav.scp`. If `${data_dir}/text` is also exists, CER will be computed
+ - `output_dir`: output dir of the recognition results
+ - `batch_size`: `64` (Default), batch size of inference on gpu
+ - `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
+ - `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
+ - `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
+ - `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
+ - `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
+ - `decoding_mode`: `normal` (Default), decoding mode for UniASR model(fast銆乶ormal銆乷ffline)
+ - `hotword_txt`: `None` (Default), hotword file for contextual paraformer model(the hotword file name ends with .txt")
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
+- Decode with multi GPUs:
+```shell
+ bash infer.sh \
+ --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+ --data_dir "./data/test" \
+ --output_dir "./results" \
+ --batch_size 64 \
+ --gpu_inference true \
+ --gpuid_list "0,1"
```
-
-### Inference using local finetuned model
-
-- Modify inference related parameters in `infer_after_finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
- - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
-
-- Then you can run the pipeline to finetune with:
-```python
- python infer_after_finetune.py
+- Decode with multi-thread CPUs:
+```shell
+ bash infer.sh \
+ --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+ --data_dir "./data/test" \
+ --output_dir "./results" \
+ --gpu_inference false \
+ --njob 64
```
- Results
-The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+
+If you decode the SpeechIO test sets, you can use textnorm with `stage`=3, and `DETAILS.txt`, `RESULTS.txt` record the results and CER after text normalization.
+
+
+## Finetune with pipeline
+
+### Quick start
+[finetune.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/finetune.py)
+```python
+import os
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+
+def modelscope_finetune(params):
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
+ # dataset split ["train", "validation"]
+ ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
+ kwargs = dict(
+ model=params.model,
+ data_dir=ds_dict,
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
+ trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ from funasr.utils.modelscope_param import modelscope_args
+ params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ params.output_dir = "./checkpoint" # 妯″瀷淇濆瓨璺緞
+ params.data_path = "speech_asr_aishell1_trainsets" # 鏁版嵁璺緞锛屽彲浠ヤ负modelscope涓凡涓婁紶鏁版嵁锛屼篃鍙互鏄湰鍦版暟鎹�
+ params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
+ params.batch_bins = 2000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
+ params.max_epoch = 50 # 鏈�澶ц缁冭疆鏁�
+ params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
+
+ modelscope_finetune(params)
+```
+
+```shell
+python finetune.py &> log.txt &
+```
+
+### Finetune with your data
+
+- Modify finetune training related parameters in [finetune.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/finetune.py)
+ - `output_dir`: result dir
+ - `data_dir`: the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - `dataset_type`: for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - `batch_bins`: batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - `max_epoch`: number of training epoch
+ - `lr`: learning rate
+
+- Training data formats锛�
+```sh
+cat ./example_data/text
+BAC009S0002W0122 鑰� 瀵� 妤� 甯� 鎴� 浜� 鎶� 鍒� 浣� 鐢� 鏈� 澶� 鐨� 闄� 璐�
+BAC009S0002W0123 涔� 鎴� 涓� 鍦� 鏂� 鏀� 搴� 鐨� 鐪� 涓� 閽�
+english_example_1 hello world
+english_example_2 go swim 鍘� 娓� 娉�
+
+cat ./example_data/wav.scp
+BAC009S0002W0122 /mnt/data/wav/train/S0002/BAC009S0002W0122.wav
+BAC009S0002W0123 /mnt/data/wav/train/S0002/BAC009S0002W0123.wav
+english_example_1 /mnt/data/wav/train/S0002/english_example_1.wav
+english_example_2 /mnt/data/wav/train/S0002/english_example_2.wav
+```
+
+- Then you can run the pipeline to finetune with:
+```shell
+python finetune.py
+```
+If you want finetune with multi-GPUs, you could:
+```shell
+CUDA_VISIBLE_DEVICES=1,2 python -m torch.distributed.launch --nproc_per_node 2 finetune.py > log.txt 2>&1
+```
+## Inference with your finetuned model
+
+- Setting parameters in [egs_modelscope/asr/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/infer.sh) is the same with [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/egs_modelscope/asr/TEMPLATE#inference-with-multi-thread-cpus-or-multi-gpus), `model` is the model name from modelscope, which you finetuned.
+
+- Decode with multi GPUs:
+```shell
+ bash infer.sh \
+ --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+ --data_dir "./data/test" \
+ --output_dir "./results" \
+ --batch_size 64 \
+ --gpu_inference true \
+ --gpuid_list "0,1" \
+ --checkpoint_dir "./checkpoint" \
+ --checkpoint_name "valid.cer_ctc.ave.pb"
+```
+- Decode with multi-thread CPUs:
+```shell
+ bash infer.sh \
+ --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+ --data_dir "./data/test" \
+ --output_dir "./results" \
+ --gpu_inference false \
+ --njob 64 \
+ --checkpoint_dir "./checkpoint" \
+ --checkpoint_name "valid.cer_ctc.ave.pb"
+```
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
new file mode 100644
index 0000000..2fce734
--- /dev/null
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
@@ -0,0 +1,16 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
+ output_dir = None
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+ vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
+ punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
+ output_dir=output_dir
+ )
+ rec_result = inference_pipeline(audio_in=audio_in)
+ print(rec_result)
+
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
index 4d98a65..5bc205c 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
@@ -1,16 +1,28 @@
+import os
+import shutil
+import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-if __name__ == '__main__':
- audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
- output_dir = None
+def modelscope_infer(args):
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
- punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
- ngpu=1,
+ model=args.model,
+ output_dir=args.output_dir,
+ batch_size=args.batch_size,
+ param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
- rec_result = inference_pipeline(audio_in=audio_in)
- print(rec_result)
+ inference_pipeline(audio_in=args.audio_in)
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
+ parser.add_argument('--output_dir', type=str, default="./results/")
+ parser.add_argument('--decoding_mode', type=str, default="normal")
+ parser.add_argument('--hotword_txt', type=str, default=None)
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--gpuid', type=str, default="0")
+ args = parser.parse_args()
+ modelscope_infer(args)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
similarity index 100%
rename from egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
rename to egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
deleted file mode 100644
index 473019c..0000000
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import json
-import os
-import shutil
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.hub.snapshot_download import snapshot_download
-
-from funasr.utils.compute_wer import compute_wer
-
-def modelscope_infer_after_finetune(params):
- # prepare for decoding
-
- try:
- pretrained_model_path = snapshot_download(params["modelscope_model_name"], cache_dir=params["output_dir"])
- except BaseException:
- raise BaseException(f"Please download pretrain model from ModelScope firstly.")shutil.copy(os.path.join(params["output_dir"], params["decoding_model_name"]), os.path.join(pretrained_model_path, "model.pb"))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
-
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=pretrained_model_path,
- output_dir=decoding_path,
- batch_size=params["batch_size"]
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in)
-
- # computer CER if GT text is set
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
-
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data/test"
- params["decoding_model_name"] = "valid.acc.ave_10best.pb"
- params["batch_size"] = 64
- modelscope_infer_after_finetune(params)
\ No newline at end of file
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/utils b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/utils
new file mode 120000
index 0000000..3d3dd06
--- /dev/null
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/utils
@@ -0,0 +1 @@
+../../asr/TEMPLATE/utils
\ No newline at end of file
diff --git a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
index ec309b2..628cdd8 100644
--- a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
+++ b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
@@ -6,12 +6,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
task=Tasks.language_score_prediction,
model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
output_dir="./tmp/"
)
-rec_result = inference_pipline(text_in=inputs)
+rec_result = inference_pipeline(text_in=inputs)
print(rec_result)
diff --git a/docs/modelscope_pipeline/punc_pipeline.md b/egs_modelscope/punctuation/TEMPLATE/README.md
similarity index 61%
rename from docs/modelscope_pipeline/punc_pipeline.md
rename to egs_modelscope/punctuation/TEMPLATE/README.md
index 5618973..08814ea 100644
--- a/docs/modelscope_pipeline/punc_pipeline.md
+++ b/egs_modelscope/punctuation/TEMPLATE/README.md
@@ -1,8 +1,7 @@
# Punctuation Restoration
-# Voice Activity Detection
> **Note**:
-> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetune. Here we take the model of the punctuation model of CT-Transformer as example to demonstrate the usage.
+> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetune. Here we take the model of the punctuation model of CT-Transformer as example to demonstrate the usage.
## Inference
@@ -12,21 +11,21 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
model_revision=None)
-rec_result = inference_pipline(text_in='example/punc_example.txt')
+rec_result = inference_pipeline(text_in='example/punc_example.txt')
print(rec_result)
```
- text浜岃繘鍒舵暟鎹紝渚嬪锛氱敤鎴风洿鎺ヤ粠鏂囦欢閲岃鍑篵ytes鏁版嵁
```python
-rec_result = inference_pipline(text_in='鎴戜滑閮芥槸鏈ㄥご浜轰笉浼氳璇濅笉浼氬姩')
+rec_result = inference_pipeline(text_in='鎴戜滑閮芥槸鏈ㄥご浜轰笉浼氳璇濅笉浼氬姩')
```
- text鏂囦欢url锛屼緥濡傦細https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt
```python
-rec_result = inference_pipline(text_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt')
+rec_result = inference_pipeline(text_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt')
```
#### [CT-Transformer Realtime model](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary)
@@ -53,15 +52,15 @@
Full code of demo, please ref to [demo](https://github.com/alibaba-damo-academy/FunASR/discussions/238)
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.punctuation`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `model_revision`: `None` (Default), setting the model version
-##### Infer pipeline
+#### Infer pipeline
- `text_in`: the input to decode, which could be:
- text bytes, `e.g.`: "鎴戜滑閮芥槸鏈ㄥご浜轰笉浼氳璇濅笉浼氬姩"
- text file, `e.g.`: example/punc_example.txt
@@ -69,38 +68,37 @@
- `param_dict`: reserving the cache which is necessary in realtime mode.
### Inference with multi-thread CPUs or multi GPUs
-FunASR also offer recipes [egs_modelscope/punc/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/punc/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs. It is an offline recipe and only support offline model.
+FunASR also offer recipes [egs_modelscope/punctuation/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/punctuation/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs. It is an offline recipe and only support offline model.
-- Setting parameters in `infer.sh`
- - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- - `data_dir`: the dataset dir needs to include `punc.txt`
- - `output_dir`: output dir of the recognition results
- - `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
- - `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
- - `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
- - `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
- - `checkpoint_name`: only used for infer finetuned models, `punc.pb` (Default), which checkpoint is used to infer
+#### Settings of `infer.sh`
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `data_dir`: the dataset dir needs to include `punc.txt`
+- `output_dir`: output dir of the recognition results
+- `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
+- `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
+- `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
+- `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
+- `checkpoint_name`: only used for infer finetuned models, `punc.pb` (Default), which checkpoint is used to infer
-- Decode with multi GPUs:
+#### Decode with multi GPUs:
```shell
bash infer.sh \
--model "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
- --batch_size 64 \
+ --batch_size 1 \
--gpu_inference true \
--gpuid_list "0,1"
```
-- Decode with multi-thread CPUs:
+#### Decode with multi-thread CPUs:
```shell
bash infer.sh \
--model "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
--gpu_inference false \
- --njob 64
+ --njob 1
```
-
## Finetune with pipeline
diff --git a/egs_modelscope/punctuation/TEMPLATE/infer.py b/egs_modelscope/punctuation/TEMPLATE/infer.py
new file mode 100644
index 0000000..edcefbe
--- /dev/null
+++ b/egs_modelscope/punctuation/TEMPLATE/infer.py
@@ -0,0 +1,23 @@
+import os
+import shutil
+import argparse
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+def modelscope_infer(args):
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
+ inference_pipeline = pipeline(
+ task=Tasks.punctuation,
+ model=args.model,
+ output_dir=args.output_dir,
+ )
+ inference_pipeline(text_in=args.text_in)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', type=str, default="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
+ parser.add_argument('--text_in', type=str, default="./data/test/punc.txt")
+ parser.add_argument('--output_dir', type=str, default="./results/")
+ parser.add_argument('--gpuid', type=str, default="0")
+ args = parser.parse_args()
+ modelscope_infer(args)
\ No newline at end of file
diff --git a/egs_modelscope/punctuation/TEMPLATE/infer.sh b/egs_modelscope/punctuation/TEMPLATE/infer.sh
new file mode 100644
index 0000000..0af502e
--- /dev/null
+++ b/egs_modelscope/punctuation/TEMPLATE/infer.sh
@@ -0,0 +1,66 @@
+#!/usr/bin/env bash
+
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=2
+model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+data_dir="./data/test"
+output_dir="./results"
+gpu_inference=true # whether to perform gpu decoding
+gpuid_list="0,1" # set gpus, e.g., gpuid_list="0,1"
+njob=64 # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
+checkpoint_dir=
+checkpoint_name="punc.pb"
+
+. utils/parse_options.sh || exit 1;
+
+if ${gpu_inference} == "true"; then
+ nj=$(echo $gpuid_list | awk -F "," '{print NF}')
+else
+ nj=$njob
+ gpuid_list=""
+ for JOB in $(seq ${nj}); do
+ gpuid_list=$gpuid_list"-1,"
+ done
+fi
+
+mkdir -p $output_dir/split
+split_scps=""
+for JOB in $(seq ${nj}); do
+ split_scps="$split_scps $output_dir/split/text.$JOB.scp"
+done
+perl utils/split_scp.pl ${data_dir}/punc.txt ${split_scps}
+
+if [ -n "${checkpoint_dir}" ]; then
+ python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
+ model=${checkpoint_dir}/${model}
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
+ echo "Decoding ..."
+ gpuid_list_array=(${gpuid_list//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+ mkdir -p ${output_dir}/output.$JOB
+ python infer.py \
+ --model ${model} \
+ --text_in ${output_dir}/split/text.$JOB.scp \
+ --output_dir ${output_dir}/output.$JOB \
+ --gpuid ${gpuid}
+ }&
+ done
+ wait
+
+ mkdir -p ${output_dir}/final_res
+ if [ -f "${output_dir}/output.1/infer.out" ]; then
+ for i in $(seq "${nj}"); do
+ cat "${output_dir}/output.${i}/infer.out"
+ done | sort -k1 >"${output_dir}/final_res/infer.out"
+ fi
+fi
+
diff --git a/egs_modelscope/punctuation/TEMPLATE/utils b/egs_modelscope/punctuation/TEMPLATE/utils
new file mode 120000
index 0000000..dc7d417
--- /dev/null
+++ b/egs_modelscope/punctuation/TEMPLATE/utils
@@ -0,0 +1 @@
+../../../egs/aishell/transformer/utils
\ No newline at end of file
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
similarity index 100%
rename from egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
rename to egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md
deleted file mode 100644
index b125d48..0000000
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained ModelScope Model
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-task=Tasks.punctuation,
- model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
-
-- Setting parameters in `modelscope_common_infer.sh`
- - <strong>model:</strong> damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch # pre-trained model, download from modelscope
- - <strong>text_in:</strong> input path, text or url
- - <strong>output_dir:</strong> the result dir
-- Then you can run the pipeline to infer with:
-```sh
- python ./infer.py
-```
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md
new file mode 120000
index 0000000..92088a2
--- /dev/null
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/README.md
@@ -0,0 +1 @@
+../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt
deleted file mode 100644
index 367be79..0000000
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-1 璺ㄥ娌虫祦鏄吇鑲叉部宀镐汉姘戠殑鐢熷懡涔嬫簮闀挎湡浠ユ潵涓哄府鍔╀笅娓稿湴鍖洪槻鐏惧噺鐏句腑鏂规妧鏈汉鍛樺湪涓婃父鍦板尯鏋佷负鎭跺姡鐨勮嚜鐒舵潯浠朵笅鍏嬫湇宸ㄥぇ鍥伴毦鐢氳嚦鍐掔潃鐢熷懡鍗遍櫓鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒跺嚒鏄腑鏂硅兘鍋氱殑鎴戜滑閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛﹁鍒掑拰璁鸿瘉鍏奸【涓婁笅娓哥殑鍒╃泭
-2 浠庡瓨鍌ㄤ笂鏉ヨ浠呬粎鏄叏鏅浘鐗囧畠灏变細鏄浘鐗囩殑鍥涘�嶇殑瀹归噺鐒跺悗鍏ㄦ櫙鐨勮棰戜細鏄櫘閫氳棰戝叓鍊嶇殑杩欎釜瀛樺偍鐨勫瑕佹眰鑰屼笁d鐨勬ā鍨嬩細鏄浘鐗囩殑鍗佸�嶈繖閮藉鎴戜滑浠婂ぉ杩愯鍦ㄧ殑浜戣绠楃殑骞冲彴瀛樺偍鐨勫钩鍙版彁鍑轰簡鏇撮珮鐨勮姹�
-3 閭d粖澶╃殑浼氬氨鍒拌繖閲屽惂 happy new year 鏄庡勾瑙�
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
similarity index 89%
rename from egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
rename to egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
index 0da8d25..20994d3 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
@@ -12,12 +12,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
model_revision="v1.1.7",
output_dir="./tmp/"
)
-rec_result = inference_pipline(text_in=inputs)
+rec_result = inference_pipeline(text_in=inputs)
print(rec_result)
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
new file mode 120000
index 0000000..f05fbbb
--- /dev/null
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
@@ -0,0 +1 @@
+../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.sh b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.sh
new file mode 120000
index 0000000..0b3b38b
--- /dev/null
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.sh
@@ -0,0 +1 @@
+../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/speaker_diarization/TEMPLATE/README.md b/egs_modelscope/speaker_diarization/TEMPLATE/README.md
index 2cd702c..ba179ed 100644
--- a/egs_modelscope/speaker_diarization/TEMPLATE/README.md
+++ b/egs_modelscope/speaker_diarization/TEMPLATE/README.md
@@ -2,7 +2,7 @@
> **Note**:
> The modelscope pipeline supports all the models in
-[model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope)
+[model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope)
to inference and finetine. Here we take the model of xvector_sv as example to demonstrate the usage.
## Inference with pipeline
@@ -37,10 +37,10 @@
print(results)
```
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.speaker_diarization`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
@@ -50,7 +50,7 @@
- vad format: spk1: [1.0, 3.0], [5.0, 8.0]
- rttm format: "SPEAKER test1 0 1.00 2.00 <NA> <NA> spk1 <NA> <NA>" and "SPEAKER test1 0 5.00 3.00 <NA> <NA> spk1 <NA> <NA>"
-##### Infer pipeline for speaker embedding extraction
+#### Infer pipeline for speaker embedding extraction
- `audio_in`: the input to process, which could be:
- list of url: `e.g.`: waveform files at a website
- list of local file path: `e.g.`: path/to/a.wav
diff --git a/egs_modelscope/speaker_verification/TEMPLATE/README.md b/egs_modelscope/speaker_verification/TEMPLATE/README.md
index 957da90..d6736e3 100644
--- a/egs_modelscope/speaker_verification/TEMPLATE/README.md
+++ b/egs_modelscope/speaker_verification/TEMPLATE/README.md
@@ -2,7 +2,7 @@
> **Note**:
> The modelscope pipeline supports all the models in
-[model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope)
+[model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope)
to inference and finetine. Here we take the model of xvector_sv as example to demonstrate the usage.
## Inference with pipeline
@@ -47,17 +47,17 @@
```
Full code of demo, please ref to [infer.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py).
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.speaker_verification`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
- `sv_threshold`: `0.9465` (Default), the similarity threshold to determine
whether utterances belong to the same speaker (it should be in (0, 1))
-##### Infer pipeline for speaker embedding extraction
+#### Infer pipeline for speaker embedding extraction
- `audio_in`: the input to process, which could be:
- url (str): `e.g.`: https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav
- local_path: `e.g.`: path/to/a.wav
@@ -71,7 +71,7 @@
- fbank1.scp,speech,kaldi_ark: `e.g.`: extracted 80-dimensional fbank features
with kaldi toolkits.
-##### Infer pipeline for speaker verification
+#### Infer pipeline for speaker verification
- `audio_in`: the input to process, which could be:
- Tuple(url1, url2): `e.g.`: (https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav, https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav)
- Tuple(local_path1, local_path2): `e.g.`: (path/to/a.wav, path/to/b.wav)
diff --git a/egs_modelscope/tp/TEMPLATE/README.md b/egs_modelscope/tp/TEMPLATE/README.md
index 2678a7f..7cc8508 100644
--- a/egs_modelscope/tp/TEMPLATE/README.md
+++ b/egs_modelscope/tp/TEMPLATE/README.md
@@ -8,12 +8,12 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
task=Tasks.speech_timestamp,
model='damo/speech_timestamp_prediction-v1-16k-offline',
output_dir=None)
-rec_result = inference_pipline(
+rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
text_in='涓� 涓� 涓� 澶� 骞� 娲� 鍥� 瀹� 涓� 浠� 涔� 璺� 鍒� 瑗� 澶� 骞� 娲� 鏉� 浜� 鍛�',)
print(rec_result)
@@ -23,15 +23,15 @@
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.speech_timestamp`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
-##### Infer pipeline
+#### Infer pipeline
- `audio_in`: the input speech to predict, which could be:
- wav_path, `e.g.`: asr_example.wav (wav in local or url),
- wav.scp, kaldi style wav list (`wav_id wav_path`), `e.g.`:
@@ -59,37 +59,37 @@
```
### Inference with multi-thread CPUs or multi GPUs
-FunASR also offer recipes [egs_modelscope/vad/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/vad/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs.
+FunASR also offer recipes [egs_modelscope/tp/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/tp/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs.
-- Setting parameters in `infer.sh`
- - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- - `data_dir`: the dataset dir **must** include `wav.scp` and `text.scp`
- - `output_dir`: output dir of the recognition results
- - `batch_size`: `64` (Default), batch size of inference on gpu
- - `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
- - `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
- - `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
- - `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
- - `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
+#### Settings of `infer.sh`
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `data_dir`: the dataset dir **must** include `wav.scp` and `text.txt`
+- `output_dir`: output dir of the recognition results
+- `batch_size`: `64` (Default), batch size of inference on gpu
+- `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
+- `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
+- `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
+- `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
+- `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
-- Decode with multi GPUs:
+#### Decode with multi GPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
- --batch_size 64 \
+ --batch_size 1 \
--gpu_inference true \
--gpuid_list "0,1"
```
-- Decode with multi-thread CPUs:
+#### Decode with multi-thread CPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
--gpu_inference false \
- --njob 64
+ --njob 1
```
## Finetune with pipeline
diff --git a/egs_modelscope/tp/TEMPLATE/infer.py b/egs_modelscope/tp/TEMPLATE/infer.py
deleted file mode 120000
index df5dff2..0000000
--- a/egs_modelscope/tp/TEMPLATE/infer.py
+++ /dev/null
@@ -1 +0,0 @@
-../speech_timestamp_prediction-v1-16k-offline/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py b/egs_modelscope/tp/TEMPLATE/infer.py
similarity index 100%
rename from egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
rename to egs_modelscope/tp/TEMPLATE/infer.py
diff --git a/egs_modelscope/tp/TEMPLATE/infer.sh b/egs_modelscope/tp/TEMPLATE/infer.sh
index 2a923bb..bae62e8 100644
--- a/egs_modelscope/tp/TEMPLATE/infer.sh
+++ b/egs_modelscope/tp/TEMPLATE/infer.sh
@@ -37,7 +37,7 @@
split_texts="$split_texts $output_dir/split/text.$JOB.scp"
done
perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
-perl utils/split_scp.pl ${data_dir}/text.scp ${split_texts}
+perl utils/split_scp.pl ${data_dir}/text.txt ${split_texts}
if [ -n "${checkpoint_dir}" ]; then
python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
deleted file mode 100644
index 5488aaa..0000000
--- a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
+++ /dev/null
@@ -1,25 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained ModelScope Model
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>text_in:</strong> # support text, text url.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
-
-Modify inference related parameters in vad.yaml.
-
-- max_end_silence_time: The end-point silence duration to judge the end of sentence, the parameter range is 500ms~6000ms, and the default value is 800ms
-- speech_noise_thres: The balance of speech and silence scores, the parameter range is (-1,1)
- - The value tends to -1, the greater probability of noise being judged as speech
- - The value tends to 1, the greater probability of speech being judged as noise
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
new file mode 100644
index 0000000..bcc5128
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
@@ -0,0 +1,12 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.speech_timestamp,
+ model='damo/speech_timestamp_prediction-v1-16k-offline',
+ output_dir=None)
+
+rec_result = inference_pipeline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
+ text_in='涓� 涓� 涓� 澶� 骞� 娲� 鍥� 瀹� 涓� 浠� 涔� 璺� 鍒� 瑗� 澶� 骞� 娲� 鏉� 浜� 鍛�',)
+print(rec_result)
\ No newline at end of file
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.sh b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.sh
new file mode 120000
index 0000000..5e59f18
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.sh
@@ -0,0 +1 @@
+../../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/vad/TEMPLATE/README.md b/egs_modelscope/vad/TEMPLATE/README.md
index 6f746d5..4c6f8c2 100644
--- a/egs_modelscope/vad/TEMPLATE/README.md
+++ b/egs_modelscope/vad/TEMPLATE/README.md
@@ -1,7 +1,7 @@
# Voice Activity Detection
> **Note**:
-> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetune. Here we take the model of FSMN-VAD as example to demonstrate the usage.
+> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetune. Here we take the model of FSMN-VAD as example to demonstrate the usage.
## Inference
@@ -43,15 +43,15 @@
-#### API-reference
-##### Define pipeline
+### API-reference
+#### Define pipeline
- `task`: `Tasks.voice_activity_detection`
-- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
-##### Infer pipeline
+#### Infer pipeline
- `audio_in`: the input to decode, which could be:
- wav_path, `e.g.`: asr_example.wav,
- pcm_path, `e.g.`: asr_example.pcm,
@@ -69,35 +69,35 @@
### Inference with multi-thread CPUs or multi GPUs
FunASR also offer recipes [egs_modelscope/vad/TEMPLATE/infer.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/vad/TEMPLATE/infer.sh) to decode with multi-thread CPUs, or multi GPUs.
-- Setting parameters in `infer.sh`
- - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- - `data_dir`: the dataset dir needs to include `wav.scp`
- - `output_dir`: output dir of the recognition results
- - `batch_size`: `64` (Default), batch size of inference on gpu
- - `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
- - `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
- - `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
- - `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
- - `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
+#### Settings of `infer.sh`
+- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
+- `data_dir`: the dataset dir needs to include `wav.scp`
+- `output_dir`: output dir of the recognition results
+- `batch_size`: `64` (Default), batch size of inference on gpu
+- `gpu_inference`: `true` (Default), whether to perform gpu decoding, set false for CPU inference
+- `gpuid_list`: `0,1` (Default), which gpu_ids are used to infer
+- `njob`: only used for CPU inference (`gpu_inference`=`false`), `64` (Default), the number of jobs for CPU decoding
+- `checkpoint_dir`: only used for infer finetuned models, the path dir of finetuned models
+- `checkpoint_name`: only used for infer finetuned models, `valid.cer_ctc.ave.pb` (Default), which checkpoint is used to infer
-- Decode with multi GPUs:
+#### Decode with multi GPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
- --batch_size 64 \
+ --batch_size 1 \
--gpu_inference true \
--gpuid_list "0,1"
```
-- Decode with multi-thread CPUs:
+#### Decode with multi-thread CPUs:
```shell
bash infer.sh \
--model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
--data_dir "./data/test" \
--output_dir "./results" \
--gpu_inference false \
- --njob 64
+ --njob 1
```
## Finetune with pipeline
diff --git a/egs_modelscope/vad/TEMPLATE/infer.sh b/egs_modelscope/vad/TEMPLATE/infer.sh
index 7dc0387..0651c98 100644
--- a/egs_modelscope/vad/TEMPLATE/infer.sh
+++ b/egs_modelscope/vad/TEMPLATE/infer.sh
@@ -9,7 +9,7 @@
model="damo/speech_fsmn_vad_zh-cn-16k-common"
data_dir="./data/test"
output_dir="./results"
-batch_size=64
+batch_size=1
gpu_inference=true # whether to perform gpu decoding
gpuid_list="0,1" # set gpus, e.g., gpuid_list="0,1"
njob=64 # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
deleted file mode 100644
index 6d9cd30..0000000
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained ModelScope Model
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
-
-Modify inference related parameters in vad.yaml.
-
-- max_end_silence_time: The end-point silence duration to judge the end of sentence, the parameter range is 500ms~6000ms, and the default value is 800ms
-- speech_noise_thres: The balance of speech and silence scores, the parameter range is (-1,1)
- - The value tends to -1, the greater probability of noise being judged as speech
- - The value tends to 1, the greater probability of speech being judged as noise
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo.py
similarity index 82%
rename from egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
rename to egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo.py
index 2bf3251..bbc16c5 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo.py
@@ -4,12 +4,12 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
model_revision='v1.2.0',
output_dir=output_dir,
batch_size=1,
)
- segments_result = inference_pipline(audio_in=audio_in)
+ segments_result = inference_pipeline(audio_in=audio_in)
print(segments_result)
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo_online.py
similarity index 89%
rename from egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
rename to egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo_online.py
index 02e919d..65693b5 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/demo_online.py
@@ -8,7 +8,7 @@
if __name__ == '__main__':
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
model_revision='v1.2.0',
@@ -30,7 +30,7 @@
else:
is_final = False
param_dict['is_final'] = is_final
- segments_result = inference_pipline(audio_in=speech[sample_offset: sample_offset + step],
+ segments_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
param_dict=param_dict)
print(segments_result)
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.sh b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.sh
new file mode 120000
index 0000000..5e59f18
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer.sh
@@ -0,0 +1 @@
+../../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
deleted file mode 100644
index 6d9cd30..0000000
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained ModelScope Model
-
-### Inference
-
-Or you can use the finetuned model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
-
-Modify inference related parameters in vad.yaml.
-
-- max_end_silence_time: The end-point silence duration to judge the end of sentence, the parameter range is 500ms~6000ms, and the default value is 800ms
-- speech_noise_thres: The balance of speech and silence scores, the parameter range is (-1,1)
- - The value tends to -1, the greater probability of noise being judged as speech
- - The value tends to 1, the greater probability of speech being judged as noise
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo.py
similarity index 82%
rename from egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
rename to egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo.py
index 2e50275..84863d0 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo.py
@@ -4,12 +4,12 @@
if __name__ == '__main__':
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example_8k.wav'
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model="damo/speech_fsmn_vad_zh-cn-8k-common",
model_revision='v1.2.0',
output_dir=output_dir,
batch_size=1,
)
- segments_result = inference_pipline(audio_in=audio_in)
+ segments_result = inference_pipeline(audio_in=audio_in)
print(segments_result)
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo_online.py
similarity index 89%
rename from egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
rename to egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo_online.py
index a8cc912..5b67da7 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/demo_online.py
@@ -8,7 +8,7 @@
if __name__ == '__main__':
output_dir = None
- inference_pipline = pipeline(
+ inference_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model="damo/speech_fsmn_vad_zh-cn-8k-common",
model_revision='v1.2.0',
@@ -30,7 +30,7 @@
else:
is_final = False
param_dict['is_final'] = is_final
- segments_result = inference_pipline(audio_in=speech[sample_offset: sample_offset + step],
+ segments_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + step],
param_dict=param_dict)
print(segments_result)
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.sh b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.sh
new file mode 120000
index 0000000..5e59f18
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer.sh
@@ -0,0 +1 @@
+../../TEMPLATE/infer.sh
\ No newline at end of file
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4722602..a52e94a 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -41,6 +41,7 @@
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -92,7 +93,11 @@
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ if asr_train_args.frontend=='wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ else:
+ frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +116,7 @@
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
@@ -193,7 +198,7 @@
"""
assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -280,6 +285,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@@ -310,6 +316,7 @@
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
+ mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -342,6 +349,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@@ -355,6 +363,9 @@
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +419,7 @@
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
+ mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@@ -416,7 +428,7 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -452,7 +464,7 @@
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
@@ -463,6 +475,9 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}\n".format(text))
return asr_result_list
return _forward
@@ -637,4 +652,4 @@
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e10ebf4..9a1ffe5 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -71,7 +71,13 @@
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
+ group.add_argument(
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -288,6 +294,9 @@
if mode == "asr":
from funasr.bin.asr_inference import inference
return inference(**kwargs)
+ elif mode == "sa_asr":
+ from funasr.bin.sa_asr_inference import inference
+ return inference(**kwargs)
elif mode == "uniasr":
from funasr.bin.asr_inference_uniasr import inference
return inference(**kwargs)
@@ -342,4 +351,4 @@
if __name__ == "__main__":
- main()
+ main()
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 5546c92..5335860 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -41,6 +41,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.bin.tp_inference import SpeechText2Timestamp
@@ -236,7 +237,7 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- if not isinstance(self.asr_model, ContextualParaformer):
+ if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
if self.hotword_list:
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index ff8bb8c..4f04d02 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
+import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -204,9 +205,12 @@
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+ if cache_en["start_idx"] == 0:
+ return []
cache_en["tail_chunk"] = True
feats = cache_en["feats"]
feats_len = torch.tensor([feats.shape[1]])
+ self.asr_model.frontend = None
results = self.infer(feats, feats_len, cache)
return results
else:
@@ -235,7 +239,7 @@
feats_len = torch.tensor([feats_chunk2.shape[1]])
results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
- return ["".join(results_chunk1 + results_chunk2)]
+ return [" ".join(results_chunk1 + results_chunk2)]
results = self.infer(feats, feats_len, cache)
@@ -295,12 +299,9 @@
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
+ token = " ".join(token)
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append(text)
+ results.append(token)
# assert check_return_type(results)
return results
@@ -515,6 +516,8 @@
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
@@ -531,13 +534,32 @@
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
- input_lens = torch.tensor([raw_inputs.shape[1]])
asr_result_list = []
-
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
- cache["encoder"]["is_final"] = is_final
- asr_result = speech2text(cache, raw_inputs, input_lens)
- item = {'key': "utt", 'value': asr_result}
+ item = {}
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ sample_offset = 0
+ speech_length = raw_inputs.shape[1]
+ stride_size = chunk_size[1] * 960
+ cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+ final_result = ""
+ for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+ if sample_offset + stride_size >= speech_length - 1:
+ stride_size = speech_length - sample_offset
+ cache["encoder"]["is_final"] = True
+ else:
+ cache["encoder"]["is_final"] = False
+ input_lens = torch.tensor([stride_size])
+ asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+ if len(asr_result) != 0:
+ final_result += " ".join(asr_result) + " "
+ item = {'key': "utt", 'value': final_result.strip()}
+ else:
+ input_lens = torch.tensor([raw_inputs.shape[1]])
+ cache["encoder"]["is_final"] = is_final
+ asr_result = speech2text(cache, raw_inputs, input_lens)
+ item = {'key': "utt", 'value': " ".join(asr_result)}
+
asr_result_list.append(item)
if is_final:
cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
@@ -725,12 +747,3 @@
if __name__ == "__main__":
main()
- # from modelscope.pipelines import pipeline
- # from modelscope.utils.constant import Tasks
- #
- # inference_16k_pipline = pipeline(
- # task=Tasks.auto_speech_recognition,
- # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
- #
- # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
- # print(rec_result)
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index d964643..bd36907 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -188,18 +188,15 @@
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
- self._ctx = self.asr_model.encoder.get_encoder_input_size(
- self.window_size
- )
+ if self.streaming:
+ self._ctx = self.asr_model.encoder.get_encoder_input_size(
+ self.window_size
+ )
- #self.last_chunk_length = (
- # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
- #) * self.hop_length
-
- self.last_chunk_length = (
- self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
- )
- self.reset_inference_cache()
+ self.last_chunk_length = (
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ )
+ self.reset_inference_cache()
def reset_inference_cache(self) -> None:
"""Reset Speech2Text parameters."""
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index bba50da..a43472c 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -27,7 +27,8 @@
args = parse_args()
# setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+ if args.ngpu > 0:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
@@ -38,9 +39,9 @@
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
- if args.batch_size is not None:
+ if args.batch_size is not None and args.ngpu > 0:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
+ if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 94f7262..5c30fdb 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -83,7 +83,8 @@
finetune_configs = yaml.safe_load(f)
# set data_types
if dataset_type == "large":
- finetune_configs["dataset_conf"]["data_types"] = "sound,text"
+ if 'data_types' not in finetune_configs['dataset_conf']:
+ finetune_configs["dataset_conf"]["data_types"] = "sound,text"
finetune_configs = update_dct(configs, finetune_configs)
for key, value in finetune_configs.items():
if hasattr(args, key):
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index b2db1bf..0dc01f5 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -61,7 +61,7 @@
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
- print("start decoding!!!")
+
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
@@ -70,7 +70,7 @@
else:
precache = ""
cache = []
- data = {"text": precache + text}
+ data = {"text": precache + " " + text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
new file mode 100644
index 0000000..c894f54
--- /dev/null
+++ b/funasr/bin/sa_asr_inference.py
@@ -0,0 +1,687 @@
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
+from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.sa_asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
+
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend=='wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ else:
+ frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, None, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+ ) -> List[
+ Tuple[
+ Optional[str],
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, text_id, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if isinstance(profile, np.ndarray):
+ profile = torch.tensor(profile)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+ if isinstance(asr_enc, tuple):
+ asr_enc = asr_enc[0]
+ if isinstance(spk_enc, tuple):
+ spk_enc = spk_enc[0]
+ assert len(asr_enc) == 1, len(asr_enc)
+ assert len(spk_enc) == 1, len(spk_enc)
+
+ # c. Passed the encoder result and the beam search
+ nbest_hyps = self.beam_search(
+ asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1: last_pos]
+ else:
+ token_int = hyp.yseq[1: last_pos].tolist()
+
+ spk_weigths=torch.stack(hyp.spk_weigths, dim=0)
+
+ token_ori = self.converter.ids2tokens(token_int)
+ text_ori = self.tokenizer.tokens2text(token_ori)
+
+ text_ori_spklist = text_ori.split('$')
+ cur_index = 0
+ spk_choose = []
+ for i in range(len(text_ori_spklist)):
+ text_ori_split = text_ori_spklist[i]
+ n = len(text_ori_split)
+ spk_weights_local = spk_weigths[cur_index: cur_index + n]
+ cur_index = cur_index + n + 1
+ spk_weights_local = spk_weights_local.mean(dim=0)
+ spk_choose_local = spk_weights_local.argmax(-1)
+ spk_choose.append(spk_choose_local.item() + 1)
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ text_spklist = text.split('$')
+ assert len(spk_choose) == len(text_spklist)
+
+ spk_list=[]
+ for i in range(len(text_spklist)):
+ text_split = text_spklist[i]
+ n = len(text_split)
+ spk_list.append(str(spk_choose[i]) * n)
+
+ text_id = '$'.join(spk_list)
+
+ assert len(text) == len(text_id)
+
+ results.append((text, text_id, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+def inference(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ batch_size=batch_size,
+ beam_size=beam_size,
+ ngpu=ngpu,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ penalty=penalty,
+ log_level=log_level,
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ raw_inputs=raw_inputs,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ key_file=key_file,
+ word_lm_train_config=word_lm_train_config,
+ bpemodel=bpemodel,
+ allow_variable_data_keys=allow_variable_data_keys,
+ streaming=streaming,
+ output_dir=output_dir,
+ dtype=dtype,
+ seed=seed,
+ ngram_weight=ngram_weight,
+ nbest=nbest,
+ num_workers=num_workers,
+ mc=mc,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7 .Start for-loop
+ # FIXME(kamo): The output format should be discussed about
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["text_id"][key] = text_id
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}".format(text))
+ logging.info("text_id predictions: {}\n".format(text_id))
+ return asr_result_list
+
+ return _forward
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=False,
+ action="append",
+ )
+ group.add_argument("--raw_inputs", type=list, default=None)
+ # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global cmvn file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain max output length. "
+ "If maxlenratio=0.0 (default), it uses a end-detect "
+ "function "
+ "to automatically find maximum hypothesis lengths."
+ "If maxlenratio<0.0, its absolute value is interpreted"
+ "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.5,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+ inference(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
new file mode 100755
index 0000000..07b9b19
--- /dev/null
+++ b/funasr/bin/sa_asr_train.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.sa_asr import ASRTask
+
+
+# for ASR Training
+def parse_args():
+ parser = ASRTask.get_parser()
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args=None, cmd=None):
+ # for ASR Training
+ ASRTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ if args.ngpu > 0:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small":
+ if args.batch_size is not None and args.ngpu > 0:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None and args.ngpu > 0:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 387b622..5fbd844 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -274,8 +274,7 @@
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
+
logging.basicConfig(
level=log_level,
@@ -286,7 +285,7 @@
device = "cuda"
else:
device = "cpu"
-
+ batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
@@ -352,7 +351,6 @@
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
- results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
@@ -377,10 +375,7 @@
**kwargs,
):
assert check_argument_types()
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
+
logging.basicConfig(
level=log_level,
@@ -391,6 +386,7 @@
device = "cuda"
else:
device = "cpu"
+ batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
@@ -466,7 +462,6 @@
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
- results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
index 4d02620..a363309 100644
--- a/funasr/bin/vad_inference_online.py
+++ b/funasr/bin/vad_inference_online.py
@@ -156,8 +156,6 @@
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
@@ -168,7 +166,7 @@
device = "cuda"
else:
device = "cpu"
-
+ batch_size = 1
# 1. Set random-seed
set_all_random_seed(seed)
@@ -243,7 +241,6 @@
item = {'key': keys[i], 'value': results[i]}
vad_results.append(item)
if writer is not None:
- results[i] = json.loads(results[i])
ibest_writer["text"][keys[i]] = "{}".format(results[i])
return vad_results
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index b0e1b8f..8c224d8 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -101,7 +101,7 @@
if data_type == "kaldi_ark":
ark_reader = ReadHelper('ark:{}'.format(data_file))
reader_list.append(ark_reader)
- elif data_type == "text" or data_type == "sound":
+ elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
text_reader = open(data_file, "r")
reader_list.append(text_reader)
elif data_type == "none":
@@ -131,6 +131,13 @@
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
sample_dict["key"] = key
+ elif data_type == "text_hotword":
+ text = item
+ segs = text.strip().split()
+ sample_dict[data_name] = segs[1:]
+ if "key" not in sample_dict:
+ sample_dict["key"] = segs[0]
+ sample_dict['hw_tag'] = 1
else:
text = item
segs = text.strip().split()
@@ -167,14 +174,38 @@
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
data_types = conf.get("data_types", "kaldi_ark,text")
- dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
+
+ pre_hwfile = conf.get("pre_hwlist", None)
+ pre_prob = conf.get("pre_prob", 0) # unused yet
+
+ hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
+ "double_rate": conf.get("double_rate", 0.1),
+ "hotword_min_length": conf.get("hotword_min_length", 2),
+ "hotword_max_length": conf.get("hotword_max_length", 8),
+ "pre_prob": conf.get("pre_prob", 0.0)}
+
+ if pre_hwfile is not None:
+ pre_hwlist = []
+ with open(pre_hwfile, 'r') as fin:
+ for line in fin.readlines():
+ pre_hwlist.append(line.strip())
+ else:
+ pre_hwlist = None
+
+ dataset = AudioDataset(scp_lists,
+ data_names,
+ data_types,
+ frontend_conf=frontend_conf,
+ shuffle=shuffle,
+ mode=mode,
+ )
filter_conf = conf.get('filter_conf', {})
filter_fn = partial(filter, **filter_conf)
dataset = FilterIterDataPipe(dataset, fn=filter_fn)
if "text" in data_names:
- vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
+ vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
tokenize_fn = partial(tokenize, **vocab)
dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
diff --git a/funasr/datasets/large_datasets/utils/hotword_utils.py b/funasr/datasets/large_datasets/utils/hotword_utils.py
new file mode 100644
index 0000000..fccfea6
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/hotword_utils.py
@@ -0,0 +1,32 @@
+import random
+
+def sample_hotword(length,
+ hotword_min_length,
+ hotword_max_length,
+ sample_rate,
+ double_rate,
+ pre_prob,
+ pre_index=None):
+ if length < hotword_min_length:
+ return [-1]
+ if random.random() < sample_rate:
+ if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
+ return pre_index
+ if length == hotword_min_length:
+ return [0, length-1]
+ elif random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
+ # sample two hotwords in a sentence
+ _max_hw_length = min(hotword_max_length, length // 2)
+ # first hotword
+ start1 = random.randint(0, length // 3)
+ end1 = random.randint(start1 + hotword_min_length - 1, start1 + _max_hw_length - 1)
+ # second hotword
+ start2 = random.randint(end1 + 1, length - hotword_min_length)
+ end2 = random.randint(min(length-1, start2+hotword_min_length-1), min(length-1, start2+hotword_max_length-1))
+ return [start1, end1, start2, end2]
+ else: # single hotword
+ start = random.randint(0, length - hotword_min_length)
+ end = random.randint(min(length-1, start+hotword_min_length-1), min(length-1, start+hotword_max_length-1))
+ return [start, end]
+ else:
+ return [-1]
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
index e0feac6..20ba7a3 100644
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -13,15 +13,16 @@
batch = {}
data_names = data[0].keys()
for data_name in data_names:
- if data_name == "key" or data_name =="sampling_rate":
+ if data_name == "key" or data_name == "sampling_rate":
continue
else:
- if data[0][data_name].dtype.kind == "i":
- pad_value = int_pad_value
- tensor_type = torch.int64
- else:
- pad_value = float_pad_value
- tensor_type = torch.float32
+ if data_name != 'hotword_indxs':
+ if data[0][data_name].dtype.kind == "i":
+ pad_value = int_pad_value
+ tensor_type = torch.int64
+ else:
+ pad_value = float_pad_value
+ tensor_type = torch.float32
tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
@@ -31,4 +32,47 @@
batch[data_name] = tensor_pad
batch[data_name + "_lengths"] = tensor_lengths
+ # DHA, EAHC NOT INCLUDED
+ if "hotword_indxs" in batch:
+ # if hotword indxs in batch
+ # use it to slice hotwords out
+ hotword_list = []
+ hotword_lengths = []
+ text = batch['text']
+ text_lengths = batch['text_lengths']
+ hotword_indxs = batch['hotword_indxs']
+ num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
+ B, t1 = text.shape
+ t1 += 1 # TODO: as parameter which is same as predictor_bias
+ ideal_attn = torch.zeros(B, t1, num_hw+1)
+ nth_hw = 0
+ for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
+ ideal_attn[b][:,-1] = 1
+ if hotword_indx[0] != -1:
+ start, end = int(hotword_indx[0]), int(hotword_indx[1])
+ hotword = one_text[start: end+1]
+ hotword_list.append(hotword)
+ hotword_lengths.append(end-start+1)
+ ideal_attn[b][start:end+1, nth_hw] = 1
+ ideal_attn[b][start:end+1, -1] = 0
+ nth_hw += 1
+ if len(hotword_indx) == 4 and hotword_indx[2] != -1:
+ # the second hotword if exist
+ start, end = int(hotword_indx[2]), int(hotword_indx[3])
+ hotword_list.append(one_text[start: end+1])
+ hotword_lengths.append(end-start+1)
+ ideal_attn[b][start:end+1, nth_hw-1] = 1
+ ideal_attn[b][start:end+1, -1] = 0
+ nth_hw += 1
+ hotword_list.append(torch.tensor([1]))
+ hotword_lengths.append(1)
+ hotword_pad = pad_sequence(hotword_list,
+ batch_first=True,
+ padding_value=0)
+ batch["hotword_pad"] = hotword_pad
+ batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
+ batch['ideal_attn'] = ideal_attn
+ del batch['hotword_indxs']
+ del batch['hotword_indxs_lengths']
+
return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index 0d2fd84..f0f0c66 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import re
import numpy as np
+from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
def forward_segment(text, seg_dict):
word_list = []
@@ -38,7 +39,8 @@
vocab=None,
seg_dict=None,
punc_dict=None,
- bpe_tokenizer=None):
+ bpe_tokenizer=None,
+ hw_config=None):
assert "text" in data
assert isinstance(vocab, dict)
text = data["text"]
@@ -53,6 +55,10 @@
text = seg_tokenize(text, seg_dict)
length = len(text)
+ if 'hw_tag' in data:
+ hotword_indxs = sample_hotword(length, **hw_config)
+ data['hotword_indxs'] = hotword_indxs
+ del data['hw_tag']
for i in range(length):
x = text[i]
if i == length-1 and "punc" in data and x.startswith("vad:"):
diff --git a/funasr/export/models/CT_Transformer.py b/funasr/export/models/CT_Transformer.py
index 932e3af..2319c4a 100644
--- a/funasr/export/models/CT_Transformer.py
+++ b/funasr/export/models/CT_Transformer.py
@@ -53,7 +53,7 @@
def get_dummy_inputs(self):
length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
@@ -130,7 +130,7 @@
def get_dummy_inputs(self):
length = 120
- text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
text_lengths = torch.tensor([length], dtype=torch.int32)
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
sub_masks = torch.ones(length, length, dtype=torch.float32)
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index dc872b0..d757f7f 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -46,13 +46,15 @@
if self.normalize:
# soundfile.read normalizes data to [-1,1] if dtype is not given
array, rate = librosa.load(
- wav, sr=self.dest_sample_rate, mono=not self.always_2d
+ wav, sr=self.dest_sample_rate, mono=self.always_2d
)
else:
array, rate = librosa.load(
- wav, sr=self.dest_sample_rate, mono=not self.always_2d, dtype=self.dtype
+ wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
)
+ if array.ndim==2:
+ array=array.transpose((1, 0))
return rate, array
def get_path(self, key):
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 28df73f..3ea34c0 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -79,3 +79,49 @@
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
+
+
+class NllLoss(nn.Module):
+ """Nll loss.
+
+ :param int size: the number of class
+ :param int padding_idx: ignored class id
+ :param bool normalize_length: normalize loss by sequence length if True
+ :param torch.nn.Module criterion: loss function
+ """
+
+ def __init__(
+ self,
+ size,
+ padding_idx,
+ normalize_length=False,
+ criterion=nn.NLLLoss(reduction='none'),
+ ):
+ """Construct an NllLoss object."""
+ super(NllLoss, self).__init__()
+ self.criterion = criterion
+ self.padding_idx = padding_idx
+ self.size = size
+ self.true_dist = None
+ self.normalize_length = normalize_length
+
+ def forward(self, x, target):
+ """Compute loss between x and target.
+
+ :param torch.Tensor x: prediction (batch, seqlen, class)
+ :param torch.Tensor target:
+ target signal masked with self.padding_id (batch, seqlen)
+ :return: scalar float value
+ :rtype torch.Tensor
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ with torch.no_grad():
+ ignore = target == self.padding_idx # (B,)
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ kl = self.criterion(x , target)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore, 0).sum() / denom
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
index 5401ab2..a0fe9ea 100644
--- a/funasr/models/decoder/rnnt_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -33,6 +33,7 @@
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
+ use_embed_mask: bool = False,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
@@ -66,6 +67,15 @@
self.device = next(self.parameters()).device
self.score_cache = {}
+
+ self.use_embed_mask = use_embed_mask
+ if self.use_embed_mask:
+ self._embed_mask = SpecAug(
+ time_mask_width_range=3,
+ num_time_mask=4,
+ apply_freq_mask=False,
+ apply_time_warp=False
+ )
def forward(
self,
@@ -88,6 +98,8 @@
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
+ if self.use_embed_mask and self.training:
+ dec_embed = self._embed_mask(dec_embed, label_lens)[0]
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index aed7f20..45fdda8 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -13,6 +13,7 @@
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
@@ -763,4 +764,429 @@
normalize_before,
concat_after,
),
- )
\ No newline at end of file
+ )
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ spker_embedding_dim: int = 256,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ input_layer: str = "embed",
+ use_asr_output_layer: bool = True,
+ use_spk_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_asr_output_layer:
+ self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.asr_output_layer = None
+
+ if use_spk_output_layer:
+ self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+ else:
+ self.spk_output_layer = None
+
+ self.cos_distance_att = CosineDistanceAttention()
+
+ self.decoder1 = None
+ self.decoder2 = None
+ self.decoder3 = None
+ self.decoder4 = None
+
+ def forward(
+ self,
+ asr_hs_pad: torch.Tensor,
+ spk_hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ tgt = ys_in_pad
+ # tgt_mask: (B, 1, L)
+ tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
+
+ asr_memory = asr_hs_pad
+ spk_memory = spk_hs_pad
+ memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+ # Spk decoder
+ x = self.embed(tgt)
+
+ x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+ x, tgt_mask, asr_memory, spk_memory, memory_mask
+ )
+ x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+ x, tgt_mask, spk_memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.spk_output_layer is not None:
+ x = self.spk_output_layer(x)
+ dn, weights = self.cos_distance_att(x, profile, profile_lens)
+ # Asr decoder
+ x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+ z, tgt_mask, asr_memory, memory_mask, dn
+ )
+ x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+ x, tgt_mask, asr_memory, memory_mask
+ )
+
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.asr_output_layer is not None:
+ x = self.asr_output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, weights, olens
+
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ asr_memory: torch.Tensor,
+ spk_memory: torch.Tensor,
+ profile: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+
+ x = self.embed(tgt)
+
+ if cache is None:
+ cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+ new_cache = []
+ x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+ x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+ )
+ new_cache.append(x)
+ for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+ x, tgt_mask, spk_memory, _ = decoder(
+ x, tgt_mask, spk_memory, None, cache=c
+ )
+ new_cache.append(x)
+ if self.normalize_before:
+ x = self.after_norm(x)
+ else:
+ x = x
+ if self.spk_output_layer is not None:
+ x = self.spk_output_layer(x)
+ dn, weights = self.cos_distance_att(x, profile, None)
+
+ x, tgt_mask, asr_memory, _ = self.decoder3(
+ z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+ )
+ new_cache.append(x)
+
+ for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
+ x, tgt_mask, asr_memory, _ = decoder(
+ x, tgt_mask, asr_memory, None, cache=c
+ )
+ new_cache.append(x)
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.asr_output_layer is not None:
+ y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+ return y, weights, new_cache
+
+ def score(self, ys, state, asr_enc, spk_enc, profile):
+ """Score."""
+ ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+ logp, weights, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+ )
+ return logp.squeeze(0), weights.squeeze(), state
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ spker_embedding_dim: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ asr_num_blocks: int = 6,
+ spk_num_blocks: int = 3,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_asr_output_layer: bool = True,
+ use_spk_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ ):
+ assert check_argument_types()
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ spker_embedding_dim=spker_embedding_dim,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_asr_output_layer=use_asr_output_layer,
+ use_spk_output_layer=use_spk_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+
+ attention_dim = encoder_output_size
+
+ self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ )
+ self.decoder2 = repeat(
+ spk_num_blocks - 1,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+
+ self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+ attention_dim,
+ spker_embedding_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ )
+ self.decoder4 = repeat(
+ asr_num_blocks - 1,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ ),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+ self.size = size
+ self.self_attn = self_attn
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ if self.concat_after:
+ tgt_concat = torch.cat(
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+ )
+ x = residual + self.concat_linear1(tgt_concat)
+ else:
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+ if not self.normalize_before:
+ x = self.norm1(x)
+ z = x
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, skip), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x = residual + self.dropout(skip)
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+
+ def __init__(
+ self,
+ size,
+ d_size,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+ self.size = size
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ self.spk_linear = nn.Linear(d_size, size, bias=False)
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = tgt_q
+ if self.normalize_before:
+ x = self.norm2(x)
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+ if not self.normalize_before:
+ x = self.norm2(x)
+ residual = x
+
+ if dn!=None:
+ x = x + self.spk_linear(dn)
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
new file mode 100644
index 0000000..dc820db
--- /dev/null
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -0,0 +1,372 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import numpy as np
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.e2e_asr_paraformer import Paraformer
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class NeatContextualParaformer(Paraformer):
+ def __init__(
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor = None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ target_buffer_length: int = -1,
+ inner_dim: int = 256,
+ bias_encoder_type: str = 'lstm',
+ use_decoder_embedding: bool = False,
+ crit_attn_weight: float = 0.0,
+ crit_attn_smooth: float = 0.0,
+ bias_encoder_dropout_rate: float = 0.0,
+ ):
+ assert check_argument_types()
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
+ )
+
+ if bias_encoder_type == 'lstm':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
+ self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
+ elif bias_encoder_type == 'mean':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
+ else:
+ logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
+
+ self.target_buffer_length = target_buffer_length
+ if self.target_buffer_length > 0:
+ self.hotword_buffer = None
+ self.length_record = []
+ self.current_buffer_length = 0
+ self.use_decoder_embedding = use_decoder_embedding
+ self.crit_attn_weight = crit_attn_weight
+ if self.crit_attn_weight > 0:
+ self.attn_loss = torch.nn.L1Loss()
+ self.crit_attn_smooth = crit_attn_smooth
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ hotword_pad: torch.Tensor,
+ hotword_lengths: torch.Tensor,
+ ideal_attn: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ self.step_cur += 1
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ loss_ideal = None
+
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermedaite CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
+ encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths, ideal_attn
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ if loss_ideal is not None:
+ loss = loss + loss_ideal * self.crit_attn_weight
+ stats["loss_ideal"] = loss_ideal.detach().cpu()
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def _calc_att_clas_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ hotword_pad: torch.Tensor,
+ hotword_lengths: torch.Tensor,
+ ideal_attn: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # -1. bias encoder
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hotword_pad)
+ else:
+ hw_embed = self.bias_embed(hotword_pad)
+ hw_embed, (_, _) = self.bias_encoder(hw_embed)
+ _ind = np.arange(0, hotword_pad.shape[0]).tolist()
+ selected = hw_embed[_ind, [i-1 for i in hotword_lengths.detach().cpu().tolist()]]
+ contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds, contextual_info)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ '''
+ if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
+ ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
+ attn_non_blank = attn[:,:,:,:-1]
+ ideal_attn_non_blank = ideal_attn[:,:,:-1]
+ loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
+ else:
+ loss_ideal = None
+ '''
+ loss_ideal = None
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+ if hw_list is None:
+ hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
+ hw_list_pad = pad_list(hw_list, 0)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
+ else:
+ hw_lengths = [len(i) for i in hw_list]
+ hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+ if self.use_decoder_embedding:
+ hw_embed = self.decoder.embed(hw_list_pad)
+ else:
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+ enforce_sorted=False)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index f8ba0f0..a5aaa6c 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -12,7 +12,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
@@ -62,7 +62,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- encoder: Encoder,
+ encoder: AbsEncoder,
decoder: RNNTDecoder,
joint_network: JointNetwork,
att_decoder: Optional[AbsAttDecoder] = None,
@@ -286,7 +286,7 @@
feats, feats_lengths = self.normalize(feats, feats_lengths)
# 4. Forward encoder
- encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+ encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -515,7 +515,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- encoder: Encoder,
+ encoder: AbsEncoder,
decoder: RNNTDecoder,
joint_network: JointNetwork,
att_decoder: Optional[AbsAttDecoder] = None,
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
new file mode 100644
index 0000000..f694cc2
--- /dev/null
+++ b/funasr/models/e2e_sa_asr.py
@@ -0,0 +1,520 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, NllLoss # noqa: H301
+)
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class ESPnetASRModel(AbsESPnetModel):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+ def __init__(
+ self,
+ vocab_size: int,
+ max_spk_num: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ asr_encoder: AbsEncoder,
+ spk_encoder: torch.nn.Module,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ spk_weight: float = 0.5,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ ):
+ assert check_argument_types()
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__()
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = 0
+ self.sos = 1
+ self.eos = 2
+ self.vocab_size = vocab_size
+ self.max_spk_num=max_spk_num
+ self.ignore_id = ignore_id
+ self.spk_weight = spk_weight
+ self.ctc_weight = ctc_weight
+ self.interctc_weight = interctc_weight
+ self.token_list = token_list.copy()
+
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.preencoder = preencoder
+ self.postencoder = postencoder
+ self.asr_encoder = asr_encoder
+ self.spk_encoder = spk_encoder
+
+ if not hasattr(self.asr_encoder, "interctc_use_conditioning"):
+ self.asr_encoder.interctc_use_conditioning = False
+ if self.asr_encoder.interctc_use_conditioning:
+ self.asr_encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.asr_encoder.output_size()
+ )
+
+ self.error_calculator = None
+
+
+ # we set self.decoder = None in the CTC mode since
+ # self.decoder parameters were never used and PyTorch complained
+ # and threw an Exception in the multi-GPU experiment.
+ # thanks Jeff Farris for pointing out the issue.
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+
+ self.criterion_spk = NllLoss(
+ size=max_spk_num,
+ padding_idx=ignore_id,
+ normalize_length=length_normalized_loss,
+ )
+
+ if report_cer or report_wer:
+ self.error_calculator = ErrorCalculator(
+ token_list, sym_space, sym_blank, report_cer, report_wer
+ )
+
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+
+ self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lengths: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ profile: (Batch, Length, Dim)
+ profile_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+
+ # 1. Encoder
+ asr_encoder_out, encoder_out_lens, spk_encoder_out = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(asr_encoder_out, tuple):
+ intermediate_outs = asr_encoder_out[1]
+ asr_encoder_out = asr_encoder_out[0]
+
+ loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = None, None, None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ asr_encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermedaite CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
+ asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss_asr = loss_att
+ elif self.ctc_weight == 1.0:
+ loss_asr = loss_ctc
+ else:
+ loss_asr = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+ if self.spk_weight == 0.0:
+ loss = loss_asr
+ else:
+ loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
+
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_asr=loss_asr.detach(),
+ loss_att=loss_att.detach() if loss_att is not None else None,
+ loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
+ loss_spk=loss_spk.detach() if loss_spk is not None else None,
+ acc=acc_att,
+ acc_spk=acc_spk,
+ cer=cer_att,
+ wer=wer_att,
+ cer_ctc=cer_ctc,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Dict[str, torch.Tensor]:
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+ feats, feats_lengths = speech, speech_lengths
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ feats_raw = feats.clone()
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.asr_encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.asr_encoder(
+ feats, feats_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
+ # import ipdb;ipdb.set_trace()
+ if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
+ encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+ else:
+ encoder_out_spk=encoder_out_spk_ori
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
+
+ assert encoder_out.size(0) == speech.size(0), (
+ encoder_out.size(),
+ speech.size(0),
+ )
+ assert encoder_out.size(1) <= encoder_out_lens.max(), (
+ encoder_out.size(),
+ encoder_out_lens.max(),
+ )
+ assert encoder_out_spk.size(0) == speech.size(0), (
+ encoder_out_spk.size(),
+ speech.size(0),
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, encoder_out_lens, encoder_out_spk
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute negative log likelihood(nll) from transformer-decoder
+
+ Normally, this function is called in batchify_nll.
+
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ """
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ ) # [batch, seqlen, dim]
+ batch_size = decoder_out.size(0)
+ decoder_num_class = decoder_out.size(2)
+ # nll: negative log-likelihood
+ nll = torch.nn.functional.cross_entropy(
+ decoder_out.view(-1, decoder_num_class),
+ ys_out_pad.view(-1),
+ ignore_index=self.ignore_id,
+ reduction="none",
+ )
+ nll = nll.view(batch_size, -1)
+ nll = nll.sum(dim=1)
+ assert nll.size(0) == batch_size
+ return nll
+
+ def batchify_nll(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ batch_size: int = 100,
+ ):
+ """Compute negative log likelihood(nll) from transformer-decoder
+
+ To avoid OOM, this fuction seperate the input into batches.
+ Then call nll for each batch and combine and return results.
+ Args:
+ encoder_out: (Batch, Length, Dim)
+ encoder_out_lens: (Batch,)
+ ys_pad: (Batch, Length)
+ ys_pad_lens: (Batch,)
+ batch_size: int, samples each batch contain when computing nll,
+ you may change this to avoid OOM or increase
+ GPU memory usage
+ """
+ total_num = encoder_out.size(0)
+ if total_num <= batch_size:
+ nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+ else:
+ nll = []
+ start_idx = 0
+ while True:
+ end_idx = min(start_idx + batch_size, total_num)
+ batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
+ batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
+ batch_ys_pad = ys_pad[start_idx:end_idx, :]
+ batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
+ batch_nll = self.nll(
+ batch_encoder_out,
+ batch_encoder_out_lens,
+ batch_ys_pad,
+ batch_ys_pad_lens,
+ )
+ nll.append(batch_nll)
+ start_idx = end_idx
+ if start_idx == total_num:
+ break
+ nll = torch.cat(nll)
+ assert nll.size(0) == total_num
+ return nll
+
+ def _calc_att_loss(
+ self,
+ asr_encoder_out: torch.Tensor,
+ spk_encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ profile: torch.Tensor,
+ profile_lens: torch.Tensor,
+ text_id: torch.Tensor,
+ text_id_lengths: torch.Tensor
+ ):
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, weights_no_pad, _ = self.decoder(
+ asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+ )
+
+ spk_num_no_pad=weights_no_pad.size(-1)
+ pad=(0,self.max_spk_num-spk_num_no_pad)
+ weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+
+ # pre_id=weights.argmax(-1)
+ # pre_text=decoder_out.argmax(-1)
+ # id_mask=(pre_id==text_id).to(dtype=text_id.dtype)
+ # pre_text_mask=pre_text*id_mask+1-id_mask #鐩稿悓鐨勫湴鏂逛笉鍙橈紝涓嶅悓鐨勫湴鏂硅涓�1(<unk>)
+ # padding_mask= ys_out_pad != self.ignore_id
+ # numerator = torch.sum(pre_text_mask.masked_select(padding_mask) == ys_out_pad.masked_select(padding_mask))
+ # denominator = torch.sum(padding_mask)
+ # sd_acc = float(numerator) / float(denominator)
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ loss_spk = self.criterion_spk(torch.log(weights), text_id)
+
+ acc_spk= th_accuracy(
+ weights.view(-1, self.max_spk_num),
+ text_id,
+ ignore_label=self.ignore_id,
+ )
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 9777cee..434f2a4 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -307,7 +307,7 @@
feed_forward: torch.nn.Module,
feed_forward_macaron: torch.nn.Module,
conv_mod: torch.nn.Module,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_class: torch.nn.Module = LayerNorm,
norm_args: Dict = {},
dropout_rate: float = 0.0,
) -> None:
@@ -1145,7 +1145,7 @@
x = x[:,::self.time_reduction_factor,:]
olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
- return x, olens
+ return x, olens, None
def simu_chunk_forward(
self,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 969ddad..2a68011 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -380,7 +380,7 @@
else:
xs_pad = self.embed(xs_pad, cache)
if cache["tail_chunk"]:
- xs_pad = cache["feats"]
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
else:
xs_pad = self._add_overlap_chunk(xs_pad, cache)
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 9671fe9..2e1b0c4 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -38,6 +38,7 @@
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
+ use_channel: int = None,
):
assert check_argument_types()
super().__init__()
@@ -77,6 +78,7 @@
)
self.n_mels = n_mels
self.frontend_type = "default"
+ self.use_channel = use_channel
def output_size(self) -> int:
return self.n_mels
@@ -100,9 +102,12 @@
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
- # Select 1ch randomly
- ch = np.random.randint(input_stft.size(2))
- input_stft = input_stft[:, :, ch, :]
+ if self.use_channel == None:
+ input_stft = input_stft[:, :, 0, :]
+ else:
+ # Select 1ch randomly
+ ch = np.random.randint(input_stft.size(2))
+ input_stft = input_stft[:, :, ch, :]
else:
# Use the first channel
input_stft = input_stft[:, :, 0, :]
diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py
index 8f85de9..39d94be 100644
--- a/funasr/models/pooling/statistic_pooling.py
+++ b/funasr/models/pooling/statistic_pooling.py
@@ -83,9 +83,9 @@
num_chunk = int(math.ceil(tt / pooling_stride))
pad = pooling_size // 2
if len(xs_pad.shape) == 4:
- features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
+ features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
else:
- features = F.pad(xs_pad, (pad, pad), "reflect")
+ features = F.pad(xs_pad, (pad, pad), "replicate")
stat_list = []
for i in range(num_chunk):
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6202079..fcb3ed4 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -13,6 +13,9 @@
from torch import nn
from typing import Optional, Tuple
+import torch.nn.functional as F
+from funasr.modules.nets_utils import make_pad_mask
+
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -959,3 +962,37 @@
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+ """ Compute Cosine Distance between spk decoder output and speaker profile
+ Args:
+ profile_path: speaker profile file path (.npy file)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, spk_decoder_out, profile, profile_lens=None):
+ """
+ Args:
+ spk_decoder_out(torch.Tensor):(B, L, D)
+ spk_profiles(torch.Tensor):(B, N, D)
+ """
+ x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
+ if profile_lens is not None:
+
+ mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+ )
+ weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+ weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
+ else:
+ x = x[:, -1:, :, :]
+ weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+ weights = self.softmax(weights_not_softmax) # (B, 1, N)
+ spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
+
+ return spk_embedding, weights
diff --git a/funasr/modules/beam_search/beam_search_sa_asr.py b/funasr/modules/beam_search/beam_search_sa_asr.py
new file mode 100755
index 0000000..b2b6833
--- /dev/null
+++ b/funasr/modules/beam_search/beam_search_sa_asr.py
@@ -0,0 +1,525 @@
+"""Beam search module."""
+
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.modules.e2e_asr_common import end_detect
+from funasr.modules.scorers.scorer_interface import PartialScorerInterface
+from funasr.modules.scorers.scorer_interface import ScorerInterface
+from funasr.models.decoder.abs_decoder import AbsDecoder
+
+
+class Hypothesis(NamedTuple):
+ """Hypothesis data type."""
+
+ yseq: torch.Tensor
+ spk_weigths : List
+ score: Union[float, torch.Tensor] = 0
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
+ states: Dict[str, Any] = dict()
+
+ def asdict(self) -> dict:
+ """Convert data to JSON-friendly dict."""
+ return self._replace(
+ yseq=self.yseq.tolist(),
+ score=float(self.score),
+ scores={k: float(v) for k, v in self.scores.items()},
+ )._asdict()
+
+
+class BeamSearch(torch.nn.Module):
+ """Beam search implementation."""
+
+ def __init__(
+ self,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ beam_size: int,
+ vocab_size: int,
+ sos: int,
+ eos: int,
+ token_list: List[str] = None,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = None,
+ ):
+ """Initialize beam search.
+
+ Args:
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ token_list (list[str]): List of tokens for debug log
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ """
+ super().__init__()
+ # set scorers
+ self.weights = weights
+ self.scorers = dict()
+ self.full_scorers = dict()
+ self.part_scorers = dict()
+ # this module dict is required for recursive cast
+ # `self.to(device, dtype)` in `recog.py`
+ self.nn_dict = torch.nn.ModuleDict()
+ for k, v in scorers.items():
+ w = weights.get(k, 0)
+ if w == 0 or v is None:
+ continue
+ assert isinstance(
+ v, ScorerInterface
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
+ self.scorers[k] = v
+ if isinstance(v, PartialScorerInterface):
+ self.part_scorers[k] = v
+ else:
+ self.full_scorers[k] = v
+ if isinstance(v, torch.nn.Module):
+ self.nn_dict[k] = v
+
+ # set configurations
+ self.sos = sos
+ self.eos = eos
+ self.token_list = token_list
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
+ self.beam_size = beam_size
+ self.n_vocab = vocab_size
+ if (
+ pre_beam_score_key is not None
+ and pre_beam_score_key != "full"
+ and pre_beam_score_key not in self.full_scorers
+ ):
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+ self.pre_beam_score_key = pre_beam_score_key
+ self.do_pre_beam = (
+ self.pre_beam_score_key is not None
+ and self.pre_beam_size < self.n_vocab
+ and len(self.part_scorers) > 0
+ )
+
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+ """Get an initial hypothesis data.
+
+ Args:
+ x (torch.Tensor): The encoder output feature
+
+ Returns:
+ Hypothesis: The initial hypothesis.
+
+ """
+ init_states = dict()
+ init_scores = dict()
+ for k, d in self.scorers.items():
+ init_states[k] = d.init_state(x)
+ init_scores[k] = 0.0
+ return [
+ Hypothesis(
+ score=0.0,
+ scores=init_scores,
+ states=init_states,
+ yseq=torch.tensor([self.sos], device=x.device),
+ spk_weigths=[],
+ )
+ ]
+
+ @staticmethod
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+ """Append new token to prefix tokens.
+
+ Args:
+ xs (torch.Tensor): The prefix token
+ x (int): The new token to append
+
+ Returns:
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+ """
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+ return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ if isinstance(d, AbsDecoder):
+ scores[k], spk_weigths, states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc, spk_enc, profile)
+ else:
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], asr_enc)
+ return scores, spk_weigths, states
+
+ def score_partial(
+ self, hyp: Hypothesis, ids: torch.Tensor, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor,
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.part_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.part_scorers`
+ and tensor score values of shape: `(len(ids),)`,
+ and state dict that has string keys
+ and state values of `self.part_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.part_scorers.items():
+ if isinstance(d, AbsDecoder):
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc, spk_enc, profile)
+ else:
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], asr_enc)
+ return scores, states
+
+ def beam(
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute topk full token ids and partial token ids.
+
+ Args:
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+ Its shape is `(self.n_vocab,)`.
+ ids (torch.Tensor): The partial token ids to compute topk
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ The topk full token ids and partial token ids.
+ Their shapes are `(self.beam_size,)`
+
+ """
+ # no pre beam performed
+ if weighted_scores.size(0) == ids.size(0):
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ return top_ids, top_ids
+
+ # mask pruned in pre-beam not to select in topk
+ tmp = weighted_scores[ids]
+ weighted_scores[:] = -float("inf")
+ weighted_scores[ids] = tmp
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+ return top_ids, local_ids
+
+ @staticmethod
+ def merge_scores(
+ prev_scores: Dict[str, float],
+ next_full_scores: Dict[str, torch.Tensor],
+ full_idx: int,
+ next_part_scores: Dict[str, torch.Tensor],
+ part_idx: int,
+ ) -> Dict[str, torch.Tensor]:
+ """Merge scores for new hypothesis.
+
+ Args:
+ prev_scores (Dict[str, float]):
+ The previous hypothesis scores by `self.scorers`
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+ full_idx (int): The next token id for `next_full_scores`
+ next_part_scores (Dict[str, torch.Tensor]):
+ scores of partial tokens by `self.part_scorers`
+ part_idx (int): The new token id for `next_part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are scalar tensors by the scorers.
+
+ """
+ new_scores = dict()
+ for k, v in next_full_scores.items():
+ new_scores[k] = prev_scores[k] + v[full_idx]
+ for k, v in next_part_scores.items():
+ new_scores[k] = prev_scores[k] + v[part_idx]
+ return new_scores
+
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+ """Merge states for new hypothesis.
+
+ Args:
+ states: states of `self.full_scorers`
+ part_states: states of `self.part_scorers`
+ part_idx (int): The new token id for `part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are states of the scorers.
+
+ """
+ new_states = dict()
+ for k, v in states.items():
+ new_states[k] = v
+ for k, d in self.part_scorers.items():
+ new_states[k] = d.select_state(part_states[k], part_idx)
+ return new_states
+
+ def search(
+ self, running_hyps: List[Hypothesis], asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor
+ ) -> List[Hypothesis]:
+ """Search new tokens for running hypotheses and encoded speech x.
+
+ Args:
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
+ x (torch.Tensor): Encoded speech feature (T, D)
+
+ Returns:
+ List[Hypotheses]: Best sorted hypotheses
+
+ """
+ # import ipdb;ipdb.set_trace()
+ best_hyps = []
+ part_ids = torch.arange(self.n_vocab, device=asr_enc.device) # no pre-beam
+ for hyp in running_hyps:
+ # scoring
+ weighted_scores = torch.zeros(self.n_vocab, dtype=asr_enc.dtype, device=asr_enc.device)
+ scores, spk_weigths, states = self.score_full(hyp, asr_enc, spk_enc, profile)
+ for k in self.full_scorers:
+ weighted_scores += self.weights[k] * scores[k]
+ # partial scoring
+ if self.do_pre_beam:
+ pre_beam_scores = (
+ weighted_scores
+ if self.pre_beam_score_key == "full"
+ else scores[self.pre_beam_score_key]
+ )
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+ part_scores, part_states = self.score_partial(hyp, part_ids, asr_enc, spk_enc, profile)
+ for k in self.part_scorers:
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+ # add previous hyp score
+ weighted_scores += hyp.score
+
+ # update hyps
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+ # will be (2 x beam at most)
+ best_hyps.append(
+ Hypothesis(
+ score=weighted_scores[j],
+ yseq=self.append_token(hyp.yseq, j),
+ scores=self.merge_scores(
+ hyp.scores, scores, j, part_scores, part_j
+ ),
+ states=self.merge_states(states, part_states, part_j),
+ spk_weigths=hyp.spk_weigths+[spk_weigths],
+ )
+ )
+
+ # sort and prune 2 x beam -> beam
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+ : min(len(best_hyps), self.beam_size)
+ ]
+ return best_hyps
+
+ def forward(
+ self, asr_enc: torch.Tensor, spk_enc: torch.Tensor, profile: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ minlenratio (float): Input length ratio to obtain min output length.
+
+ Returns:
+ list[Hypothesis]: N-best decoding results
+
+ """
+ # import ipdb;ipdb.set_trace()
+ # set length bounds
+ if maxlenratio == 0:
+ maxlen = asr_enc.shape[0]
+ else:
+ maxlen = max(1, int(maxlenratio * asr_enc.size(0)))
+ minlen = int(minlenratio * asr_enc.size(0))
+ logging.info("decoder input length: " + str(asr_enc.shape[0]))
+ logging.info("max output length: " + str(maxlen))
+ logging.info("min output length: " + str(minlen))
+
+ # main loop of prefix search
+ running_hyps = self.init_hyp(asr_enc)
+ ended_hyps = []
+ for i in range(maxlen):
+ logging.debug("position " + str(i))
+ best = self.search(running_hyps, asr_enc, spk_enc, profile)
+ #import pdb;pdb.set_trace()
+ # post process of one iteration
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+ # end detection
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+ logging.info(f"end detected at {i}")
+ break
+ if len(running_hyps) == 0:
+ logging.info("no hypothesis. Finish decoding.")
+ break
+ else:
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+ # check the number of hypotheses reaching to eos
+ if len(nbest_hyps) == 0:
+ logging.warning(
+ "there is no N-best results, perform recognition "
+ "again with smaller minlenratio."
+ )
+ return (
+ []
+ if minlenratio < 0.1
+ else self.forward(asr_enc, spk_enc, profile, maxlenratio, max(0.0, minlenratio - 0.1))
+ )
+
+ # report the best result
+ best = nbest_hyps[0]
+ for k, v in best.scores.items():
+ logging.info(
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+ )
+ logging.info(f"total log probability: {best.score:.2f}")
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+ if self.token_list is not None:
+ logging.info(
+ "best hypo: "
+ + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ + "\n"
+ )
+ return nbest_hyps
+
+ def post_process(
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis],
+ ) -> List[Hypothesis]:
+ """Perform post-processing of beam search iterations.
+
+ Args:
+ i (int): The length of hypothesis tokens.
+ maxlen (int): The maximum length of tokens in beam search.
+ maxlenratio (int): The maximum length ratio in beam search.
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+ Returns:
+ List[Hypothesis]: The new running hypotheses.
+
+ """
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+ if self.token_list is not None:
+ logging.debug(
+ "best hypo: "
+ + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+ )
+ # add eos in the final loop to avoid that there are no ended hyps
+ if i == maxlen - 1:
+ logging.info("adding <eos> in the last position in the loop")
+ running_hyps = [
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
+ for h in running_hyps
+ ]
+
+ # add ended hypotheses to a final list, and removed them from current hypotheses
+ # (this will be a problem, number of hyps < beam)
+ remained_hyps = []
+ for hyp in running_hyps:
+ if hyp.yseq[-1] == self.eos:
+ # e.g., Word LM needs to add final <eos> score
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ s = d.final_score(hyp.states[k])
+ hyp.scores[k] += s
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+ ended_hyps.append(hyp)
+ else:
+ remained_hyps.append(hyp)
+ return remained_hyps
+
+
+def beam_search(
+ x: torch.Tensor,
+ sos: int,
+ eos: int,
+ beam_size: int,
+ vocab_size: int,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ token_list: List[str] = None,
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = "full",
+) -> list:
+ """Perform beam search with scorers.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ token_list (list[str]): List of tokens for debug log
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ minlenratio (float): Input length ratio to obtain min output length.
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ Returns:
+ list: N-best decoding results
+
+ """
+ ret = BeamSearch(
+ scorers,
+ weights,
+ beam_size=beam_size,
+ vocab_size=vocab_size,
+ pre_beam_ratio=pre_beam_ratio,
+ pre_beam_score_key=pre_beam_score_key,
+ sos=sos,
+ eos=eos,
+ token_list=token_list,
+ ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
+ return [h.asdict() for h in ret]
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 10df124..397a5c4 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -485,14 +485,39 @@
new_k = k.replace(old_prefix, new_prefix)
state_dict[new_k] = v
-
class Swish(torch.nn.Module):
- """Construct an Swish object."""
+ """Swish activation definition.
- def forward(self, x):
- """Return Swich activation function."""
- return x * torch.sigmoid(x)
+ Swish(x) = (beta * x) * sigmoid(x)
+ where beta = 1 defines standard Swish activation.
+ References:
+ https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+ E-swish variant: https://arxiv.org/abs/1801.07145.
+
+ Args:
+ beta: Beta parameter for E-Swish.
+ (beta >= 1. If beta < 1, use standard Swish).
+ use_builtin: Whether to use PyTorch function if available.
+
+ """
+
+ def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+ super().__init__()
+
+ self.beta = beta
+
+ if beta > 1:
+ self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+ else:
+ if use_builtin:
+ self.swish = torch.nn.SiLU()
+ else:
+ self.swish = lambda x: x * torch.sigmoid(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward computation."""
+ return self.swish(x)
def get_activation(act):
"""Return activation function."""
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index 2b2dac8..ff1e182 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -7,7 +7,7 @@
"""Repeat the same layer definition."""
from typing import Dict, List, Optional
-
+from funasr.modules.layer_norm import LayerNorm
import torch
@@ -48,7 +48,7 @@
self,
block_list: List[torch.nn.Module],
output_size: int,
- norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_class: torch.nn.Module = LayerNorm,
) -> None:
"""Construct a MultiBlocks object."""
super().__init__()
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index da92559..71bb035 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -1,4 +1,4 @@
-# Using funasr with grpc-cpp
+# Service with grpc-cpp
## For the Server
@@ -37,39 +37,32 @@
### Start grpc paraformer server
```
-./cmake/build/paraformer-server --port-id <string> [--punc-config
- <string>] [--punc-model <string>]
- --am-config <string> --am-cmvn <string>
- --am-model <string> [--vad-config
- <string>] [--vad-cmvn <string>]
- [--vad-model <string>] [--] [--version]
- [-h]
+
+./cmake/build/paraformer-server --port-id <string> [--punc-quant <string>]
+ [--punc-dir <string>] [--vad-quant <string>]
+ [--vad-dir <string>] [--quantize <string>]
+ --model-dir <string> [--] [--version] [-h]
Where:
--port-id <string>
(required) port id
+ --model-dir <string>
+ (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
- --am-config <string>
- (required) am config path
- --am-cmvn <string>
- (required) am cmvn path
- --am-model <string>
- (required) am model path
+ --vad-dir <string>
+ the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ --vad-quant <string>
+ false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
- --punc-config <string>
- punc config path
- --punc-model <string>
- punc model path
-
- --vad-config <string>
- vad config path
- --vad-cmvn <string>
- vad cmvn path
- --vad-model <string>
- vad model path
-
- Required: --port-id <string> --am-config <string> --am-cmvn <string> --am-model <string>
- If use vad, please add: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
- If use punc, please add: [--punc-config <string>] [--punc-model <string>]
+ --punc-dir <string>
+ the punc model path, which contains model.onnx, punc.yaml
+ --punc-quant <string>
+ false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+
+ Required: --port-id <string> --model-dir <string>
+ If use vad, please add: --vad-dir <string>
+ If use punc, please add: --punc-dir <string>
```
## For the client
diff --git a/funasr/runtime/grpc/paraformer-server.cc b/funasr/runtime/grpc/paraformer-server.cc
index 31333c9..3bc011a 100644
--- a/funasr/runtime/grpc/paraformer-server.cc
+++ b/funasr/runtime/grpc/paraformer-server.cc
@@ -31,7 +31,7 @@
using paraformer::ASR;
ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
- AsrHanlde=FunASRInit(model_path, 1);
+ AsrHanlde=FunOfflineInit(model_path, 1);
std::cout << "ASRServicer init" << std::endl;
init_flag = 0;
}
@@ -137,7 +137,7 @@
stream->Write(res);
}
else {
- FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
+ FUNASR_RESULT Result= FunOfflineRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
@@ -204,38 +204,30 @@
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
- TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
- TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
- TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
- TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
-
- TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
- TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
- TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
-
- cmd.add(vad_model);
- cmd.add(vad_cmvn);
- cmd.add(vad_config);
- cmd.add(am_model);
- cmd.add(am_cmvn);
- cmd.add(am_config);
- cmd.add(punc_model);
- cmd.add(punc_config);
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_quant);
cmd.add(port_id);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
- GetValue(vad_model, VAD_MODEL_PATH, model_path);
- GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
- GetValue(vad_config, VAD_CONFIG_PATH, model_path);
- GetValue(am_model, AM_MODEL_PATH, model_path);
- GetValue(am_cmvn, AM_CMVN_PATH, model_path);
- GetValue(am_config, AM_CONFIG_PATH, model_path);
- GetValue(punc_model, PUNC_MODEL_PATH, model_path);
- GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(vad_dir, VAD_DIR, model_path);
+ GetValue(vad_quant, VAD_QUANT, model_path);
+ GetValue(punc_dir, PUNC_DIR, model_path);
+ GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(port_id, PORT_ID, model_path);
RunServer(model_path);
diff --git a/funasr/runtime/grpc/paraformer-server.h b/funasr/runtime/grpc/paraformer-server.h
index 108e3b6..760ea2a 100644
--- a/funasr/runtime/grpc/paraformer-server.h
+++ b/funasr/runtime/grpc/paraformer-server.h
@@ -15,7 +15,7 @@
#include <chrono>
#include "paraformer.grpc.pb.h"
-#include "libfunasrapi.h"
+#include "funasrruntime.h"
using grpc::Server;
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 25b816f..9f6013f 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -38,5 +38,4 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(third_party/glog)
-endif()
-
+endif()
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index ab9f420..1eabd3e 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -1,16 +1,17 @@
-
#ifndef AUDIO_H
#define AUDIO_H
#include <queue>
#include <stdint.h>
-#include "model.h"
+#include "vad-model.h"
+#include "offline-stream.h"
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
#endif
using namespace std;
+namespace funasr {
class AudioFrame {
private:
@@ -54,9 +55,11 @@
int FetchChunck(float *&dout, int len);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
- void Split(Model* recog_obj);
+ void Split(OfflineStream* offline_streamj);
+ void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
float GetTimeLen();
int GetQueueSize() { return (int)frame_queue.size(); }
};
+} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 9b7b212..7a6345b 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -1,7 +1,6 @@
+#pragma once
-#ifndef COMDEFINE_H
-#define COMDEFINE_H
-
+namespace funasr {
#define S_BEGIN 0
#define S_MIDDLE 1
#define S_END 2
@@ -12,19 +11,36 @@
#define MODEL_SAMPLE_RATE 16000
#endif
-// model path
-#define VAD_MODEL_PATH "vad-model"
-#define VAD_CMVN_PATH "vad-cmvn"
-#define VAD_CONFIG_PATH "vad-config"
-#define AM_MODEL_PATH "am-model"
-#define AM_CMVN_PATH "am-cmvn"
-#define AM_CONFIG_PATH "am-config"
-#define PUNC_MODEL_PATH "punc-model"
-#define PUNC_CONFIG_PATH "punc-config"
+// parser option
+#define MODEL_DIR "model-dir"
+#define VAD_DIR "vad-dir"
+#define PUNC_DIR "punc-dir"
+#define QUANTIZE "quantize"
+#define VAD_QUANT "vad-quant"
+#define PUNC_QUANT "punc-quant"
+
#define WAV_PATH "wav-path"
#define WAV_SCP "wav-scp"
+#define TXT_PATH "txt-path"
#define THREAD_NUM "thread-num"
#define PORT_ID "port-id"
+
+// #define VAD_MODEL_PATH "vad-model"
+// #define VAD_CMVN_PATH "vad-cmvn"
+// #define VAD_CONFIG_PATH "vad-config"
+// #define AM_MODEL_PATH "am-model"
+// #define AM_CMVN_PATH "am-cmvn"
+// #define AM_CONFIG_PATH "am-config"
+// #define PUNC_MODEL_PATH "punc-model"
+// #define PUNC_CONFIG_PATH "punc-config"
+
+#define MODEL_NAME "model.onnx"
+#define QUANT_MODEL_NAME "model_quant.onnx"
+#define VAD_CMVN_NAME "vad.mvn"
+#define VAD_CONFIG_NAME "vad.yaml"
+#define AM_CMVN_NAME "am.mvn"
+#define AM_CONFIG_NAME "config.yaml"
+#define PUNC_CONFIG_NAME "punc.yaml"
// vad
#ifndef VAD_SILENCE_DURATION
@@ -60,4 +76,4 @@
#define DUN_INDEX 5
#define CACHE_POP_TRIGGER_LIMIT 200
-#endif
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
new file mode 100644
index 0000000..75be80e
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -0,0 +1,88 @@
+#pragma once
+#include <map>
+#include <vector>
+
+#ifdef WIN32
+#ifdef _FUNASR_API_EXPORT
+#define _FUNASRAPI __declspec(dllexport)
+#else
+#define _FUNASRAPI __declspec(dllimport)
+#endif
+#else
+#define _FUNASRAPI
+#endif
+
+#ifndef _WIN32
+#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
+#else
+#define FUNASR_CALLBCK_PREFIX __stdcall
+#endif
+
+#ifdef __cplusplus
+
+extern "C" {
+#endif
+
+typedef void* FUNASR_HANDLE;
+typedef void* FUNASR_RESULT;
+typedef unsigned char FUNASR_BOOL;
+
+#define FUNASR_TRUE 1
+#define FUNASR_FALSE 0
+#define QM_DEFAULT_THREAD_NUM 4
+
+typedef enum
+{
+ RASR_NONE=-1,
+ RASRM_CTC_GREEDY_SEARCH=0,
+ RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
+ RASRM_ATTENSION_RESCORING = 2,
+}FUNASR_MODE;
+
+typedef enum {
+ FUNASR_MODEL_PADDLE = 0,
+ FUNASR_MODEL_PADDLE_2 = 1,
+ FUNASR_MODEL_K2 = 2,
+ FUNASR_MODEL_PARAFORMER = 3,
+}FUNASR_MODEL_TYPE;
+
+typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
+
+// ASR
+_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
+
+_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+
+_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
+_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
+_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
+_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
+
+// VAD
+_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+
+_FUNASRAPI FUNASR_RESULT FsmnVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI std::vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result);
+_FUNASRAPI void FsmnVadUninit(FUNASR_HANDLE handle);
+_FUNASRAPI const float FsmnVadGetRetSnippetTime(FUNASR_RESULT result);
+
+// PUNC
+_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
+
+//OfflineStream
+_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_RESULT FunOfflineRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunOfflineRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
+
+#ifdef __cplusplus
+
+}
+#endif
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
deleted file mode 100644
index f65efcc..0000000
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ /dev/null
@@ -1,75 +0,0 @@
-#pragma once
-#include <map>
-
-#ifdef WIN32
-#ifdef _FUNASR_API_EXPORT
-#define _FUNASRAPI __declspec(dllexport)
-#else
-#define _FUNASRAPI __declspec(dllimport)
-#endif
-#else
-#define _FUNASRAPI
-#endif
-
-#ifndef _WIN32
-#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
-#else
-#define FUNASR_CALLBCK_PREFIX __stdcall
-#endif
-
-#ifdef __cplusplus
-
-extern "C" {
-#endif
-
-typedef void* FUNASR_HANDLE;
-typedef void* FUNASR_RESULT;
-typedef unsigned char FUNASR_BOOL;
-
-#define FUNASR_TRUE 1
-#define FUNASR_FALSE 0
-#define QM_DEFAULT_THREAD_NUM 4
-
-typedef enum
-{
- RASR_NONE=-1,
- RASRM_CTC_GREEDY_SEARCH=0,
- RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
- RASRM_ATTENSION_RESCORING = 2,
-}FUNASR_MODE;
-
-typedef enum {
- FUNASR_MODEL_PADDLE = 0,
- FUNASR_MODEL_PADDLE_2 = 1,
- FUNASR_MODEL_K2 = 2,
- FUNASR_MODEL_PARAFORMER = 3,
-}FUNASR_MODEL_TYPE;
-
-typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
-
-// // ASR
-_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
-
-_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-
-_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
-_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
-_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
-_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
-_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
-
-// VAD
-_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
-
-_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
-
-#ifdef __cplusplus
-
-}
-#endif
diff --git a/funasr/runtime/onnxruntime/include/model.h b/funasr/runtime/onnxruntime/include/model.h
index 4b4b582..44bd022 100644
--- a/funasr/runtime/onnxruntime/include/model.h
+++ b/funasr/runtime/onnxruntime/include/model.h
@@ -4,19 +4,17 @@
#include <string>
#include <map>
-
+namespace funasr {
class Model {
public:
virtual ~Model(){};
virtual void Reset() = 0;
+ virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
virtual std::string Forward(float *din, int len, int flag) = 0;
virtual std::string Rescoring() = 0;
- virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
- virtual std::string AddPunc(const char* sz_input)=0;
- virtual bool UseVad() =0;
- virtual bool UsePunc() =0;
};
Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
+} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/include/offline-stream.h b/funasr/runtime/onnxruntime/include/offline-stream.h
new file mode 100644
index 0000000..a9ce88e
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/offline-stream.h
@@ -0,0 +1,30 @@
+#ifndef OFFLINE_STREAM_H
+#define OFFLINE_STREAM_H
+
+#include <memory>
+#include <string>
+#include <map>
+#include "model.h"
+#include "punc-model.h"
+#include "vad-model.h"
+
+namespace funasr {
+class OfflineStream {
+ public:
+ OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
+ ~OfflineStream(){};
+
+ std::unique_ptr<VadModel> vad_handle;
+ std::unique_ptr<Model> asr_handle;
+ std::unique_ptr<PuncModel> punc_handle;
+ bool UseVad(){return use_vad;};
+ bool UsePunc(){return use_punc;};
+
+ private:
+ bool use_vad=false;
+ bool use_punc=false;
+};
+
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+} // namespace funasr
+#endif
diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h
new file mode 100644
index 0000000..da7ff60
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/punc-model.h
@@ -0,0 +1,20 @@
+
+#ifndef PUNC_MODEL_H
+#define PUNC_MODEL_H
+
+#include <string>
+#include <map>
+#include <vector>
+
+namespace funasr {
+class PuncModel {
+ public:
+ virtual ~PuncModel(){};
+ virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
+ virtual std::vector<int> Infer(std::vector<int32_t> input_data)=0;
+ virtual std::string AddPunc(const char* sz_input)=0;
+};
+
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
+} // namespace funasr
+#endif
diff --git a/funasr/runtime/onnxruntime/include/vad-model.h b/funasr/runtime/onnxruntime/include/vad-model.h
new file mode 100644
index 0000000..2a8d6e4
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/vad-model.h
@@ -0,0 +1,29 @@
+
+#ifndef VAD_MODEL_H
+#define VAD_MODEL_H
+
+#include <string>
+#include <map>
+#include <vector>
+
+namespace funasr {
+class VadModel {
+ public:
+ virtual ~VadModel(){};
+ virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
+ virtual std::vector<std::vector<int>> Infer(const std::vector<float> &waves)=0;
+ virtual void ReadModel(const char* vad_model)=0;
+ virtual void LoadConfigFromYaml(const char* filename)=0;
+ virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ const std::vector<float> &waves)=0;
+ virtual std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats)=0;
+ virtual void Forward(
+ const std::vector<std::vector<float>> &chunk_feats,
+ std::vector<std::vector<float>> *out_prob)=0;
+ virtual void LoadCmvn(const char *filename)=0;
+ virtual void InitCache()=0;
+};
+
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
+} // namespace funasr
+#endif
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index 7a96261..5b42c30 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -4,9 +4,10 @@
### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
```shell
-pip3 install torch torchaudio
-pip install -U modelscope
-pip install -U funasr
+# pip3 install torch torchaudio
+pip install -U modelscope funasr
+# For the users in China, you could install with the command:
+# pip install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
@@ -40,41 +41,110 @@
```
## Run the demo
+### funasr-onnx-offline
```shell
./funasr-onnx-offline [--wav-scp <string>] [--wav-path <string>]
- [--punc-config <string>] [--punc-model <string>]
- --am-config <string> --am-cmvn <string>
- --am-model <string> [--vad-config <string>]
- [--vad-cmvn <string>] [--vad-model <string>] [--]
- [--version] [-h]
+ [--punc-quant <string>] [--punc-dir <string>]
+ [--vad-quant <string>] [--vad-dir <string>]
+ [--quantize <string>] --model-dir <string>
+ [--] [--version] [-h]
Where:
+ --model-dir <string>
+ (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+
+ --vad-dir <string>
+ the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ --vad-quant <string>
+ false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+
+ --punc-dir <string>
+ the punc model path, which contains model.onnx, punc.yaml
+ --punc-quant <string>
+ false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+
+ --wav-scp <string>
+ wave scp path
+ --wav-path <string>
+ wave file path
+
+ Required: --model-dir <string>
+ If use vad, please add: --vad-dir <string>
+ If use punc, please add: --punc-dir <string>
+
+For example:
+./funasr-onnx-offline \
+ --model-dir ./asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
+ --quantize true \
+ --vad-dir ./asrmodel/speech_fsmn_vad_zh-cn-16k-common-pytorch \
+ --punc-dir ./asrmodel/punc_ct-transformer_zh-cn-common-vocab272727-pytorch \
+ --wav-path ./vad_example.wav
+```
+
+### funasr-onnx-offline-vad
+```shell
+./funasr-onnx-offline-vad [--wav-scp <string>] [--wav-path <string>]
+ [--quantize <string>] --model-dir <string>
+ [--] [--version] [-h]
+Where:
+ --model-dir <string>
+ (required) the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
--wav-scp <string>
wave scp path
--wav-path <string>
wave file path
- --punc-config <string>
- punc config path
- --punc-model <string>
- punc model path
+ Required: --model-dir <string>
- --am-config <string>
- (required) am config path
- --am-cmvn <string>
- (required) am cmvn path
- --am-model <string>
- (required) am model path
+For example:
+./funasr-onnx-offline-vad \
+ --model-dir ./asrmodel/speech_fsmn_vad_zh-cn-16k-common-pytorch \
+ --wav-path ./vad_example.wav
+```
- --vad-config <string>
- vad config path
- --vad-cmvn <string>
- vad cmvn path
- --vad-model <string>
- vad model path
-
- Required: --am-config <string> --am-cmvn <string> --am-model <string>
- If use vad, please add: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
- If use punc, please add: [--punc-config <string>] [--punc-model <string>]
+### funasr-onnx-offline-punc
+```shell
+./funasr-onnx-offline-punc [--txt-path <string>] [--quantize <string>]
+ --model-dir <string> [--] [--version] [-h]
+Where:
+ --model-dir <string>
+ (required) the punc model path, which contains model.onnx, punc.yaml
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+ --txt-path <string>
+ txt file path, one sentence per line
+
+ Required: --model-dir <string>
+
+For example:
+./funasr-onnx-offline-punc \
+ --model-dir ./asrmodel/punc_ct-transformer_zh-cn-common-vocab272727-pytorch \
+ --txt-path ./punc_example.txt
+```
+### funasr-onnx-offline-rtf
+```shell
+./funasr-onnx-offline-rtf --thread-num <int32_t> --wav-scp <string>
+ [--quantize <string>] --model-dir <string>
+ [--] [--version] [-h]
+Where:
+ --thread-num <int32_t>
+ (required) multi-thread num for rtf
+ --model-dir <string>
+ (required) the model path, which contains model.onnx, config.yaml, am.mvn
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+ --wav-scp <string>
+ (required) wave scp path
+
+For example:
+./funasr-onnx-offline-rtf \
+ --model-dir ./asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
+ --quantize true \
+ --wav-scp ./aishell1_test.scp \
+ --thread-num 32
```
## Acknowledge
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index 28a67b4..341a16a 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -26,7 +26,11 @@
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
+add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
+add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
+target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
+target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/src/alignedmem.cpp b/funasr/runtime/onnxruntime/src/alignedmem.cpp
index d3e4b82..9c7d323 100644
--- a/funasr/runtime/onnxruntime/src/alignedmem.cpp
+++ b/funasr/runtime/onnxruntime/src/alignedmem.cpp
@@ -1,4 +1,6 @@
#include "precomp.h"
+
+namespace funasr {
void *AlignedMalloc(size_t alignment, size_t required_bytes)
{
void *p1; // original block
@@ -16,3 +18,4 @@
{
free(((void **)p)[-1]);
}
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/alignedmem.h b/funasr/runtime/onnxruntime/src/alignedmem.h
index e2b640a..e4b9a78 100644
--- a/funasr/runtime/onnxruntime/src/alignedmem.h
+++ b/funasr/runtime/onnxruntime/src/alignedmem.h
@@ -2,7 +2,9 @@
#ifndef ALIGNEDMEM_H
#define ALIGNEDMEM_H
+namespace funasr {
extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
extern void AlignedFree(void *p);
+} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index d104500..6d63d67 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -11,6 +11,7 @@
using namespace std;
+namespace funasr {
// see http://soundfile.sapp.org/doc/WaveFormat/
// Note: We assume little endian here
struct WaveHeader {
@@ -237,6 +238,24 @@
LOG(ERROR) << "Failed to read " << filename;
return false;
}
+
+ if (!header.Validate()) {
+ return false;
+ }
+
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
+
+ if (!header.Validate()) {
+ return false;
+ }
+
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
*sampling_rate = header.sample_rate;
// header.subchunk2_size contains the number of bytes in the data.
@@ -380,8 +399,10 @@
FILE* fp;
fp = fopen(filename, "rb");
if (fp == nullptr)
+ {
LOG(ERROR) << "Failed to read " << filename;
return false;
+ }
fseek(fp, 0, SEEK_END);
uint32_t n_file_len = ftell(fp);
fseek(fp, 0, SEEK_SET);
@@ -494,7 +515,7 @@
delete frame;
}
-void Audio::Split(Model* recog_obj)
+void Audio::Split(OfflineStream* offline_stream)
{
AudioFrame *frame;
@@ -505,7 +526,7 @@
frame = NULL;
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
- vector<std::vector<int>> vad_segments = recog_obj->VadSeg(pcm_data);
+ vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
int seg_sample = MODEL_SAMPLE_RATE/1000;
for(vector<int> segment:vad_segments)
{
@@ -518,3 +539,20 @@
frame = NULL;
}
}
+
+
+void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
+{
+ AudioFrame *frame;
+
+ frame = frame_queue.front();
+ frame_queue.pop();
+ int sp_len = frame->GetLen();
+ delete frame;
+ frame = NULL;
+
+ std::vector<float> pcm_data(speech_data, speech_data+sp_len);
+ vad_segments = vad_obj->Infer(pcm_data);
+}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index fbbda74..d0882c6 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -1,10 +1,18 @@
#pragma once
#include <algorithm>
+
+namespace funasr {
typedef struct
{
std::string msg;
float snippet_time;
}FUNASR_RECOG_RESULT;
+
+typedef struct
+{
+ std::vector<std::vector<int>>* segments;
+ float snippet_time;
+}FUNASR_VAD_RESULT;
#ifdef _WIN32
@@ -52,3 +60,4 @@
inline static size_t Argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index ecde636..38a5a70 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -5,6 +5,7 @@
#include "precomp.h"
+namespace funasr {
CTTransformer::CTTransformer()
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
@@ -54,7 +55,7 @@
int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
int nCurBatch = -1;
int nSentEnd = -1, nLastCommaIndex = -1;
- vector<int64_t> RemainIDs; //
+ vector<int32_t> RemainIDs; //
vector<string> RemainStr; //
vector<int> NewPunctuation; //
vector<string> NewString; //
@@ -64,7 +65,7 @@
for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
{
nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
- vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
+ vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
@@ -141,12 +142,13 @@
return strResult;
}
-vector<int> CTTransformer::Infer(vector<int64_t> input_data)
+vector<int> CTTransformer::Infer(vector<int32_t> input_data)
{
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
vector<int> punction;
std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
- Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
+ Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
+ m_memoryInfo,
input_data.data(),
input_data.size(),
input_shape_.data(),
@@ -185,3 +187,4 @@
return punction;
}
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.h b/funasr/runtime/onnxruntime/src/ct-transformer.h
index d965bb3..49ed1b7 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.h
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.h
@@ -5,7 +5,8 @@
#pragma once
-class CTTransformer {
+namespace funasr {
+class CTTransformer : public PuncModel {
/**
* Author: Speech Lab of DAMO Academy, Alibaba Group
* CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -27,6 +28,7 @@
CTTransformer();
void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
~CTTransformer();
- vector<int> Infer(vector<int64_t> input_data);
+ vector<int> Infer(vector<int32_t> input_data);
string AddPunc(const char* sz_input);
};
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/e2e-vad.h b/funasr/runtime/onnxruntime/src/e2e-vad.h
index 90f2635..5ece1f8 100644
--- a/funasr/runtime/onnxruntime/src/e2e-vad.h
+++ b/funasr/runtime/onnxruntime/src/e2e-vad.h
@@ -1,7 +1,10 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
+ * Contributed by zhuzizyf(China Telecom).
*/
+
+#pragma once
#include <utility>
#include <vector>
@@ -13,7 +16,7 @@
#include <numeric>
#include <cassert>
-
+namespace funasr {
enum class VadStateMachine {
kVadInStateStartPointNotDetected = 1,
kVadInStateInSpeechSegment = 2,
@@ -381,10 +384,11 @@
int max_end_sil_frame_cnt_thresh;
float speech_noise_thres;
std::vector<std::vector<float>> scores;
+ int idx_pre_chunk = 0;
bool max_time_out;
std::vector<float> decibel;
- std::vector<float> data_buf;
- std::vector<float> data_buf_all;
+ int data_buf_size = 0;
+ int data_buf_all_size = 0;
std::vector<float> waveform;
void AllResetDetection() {
@@ -409,10 +413,11 @@
max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres;
speech_noise_thres = vad_opts.speech_noise_thres;
scores.clear();
+ idx_pre_chunk = 0;
max_time_out = false;
decibel.clear();
- data_buf.clear();
- data_buf_all.clear();
+ int data_buf_size = 0;
+ int data_buf_all_size = 0;
waveform.clear();
ResetDetection();
}
@@ -432,18 +437,17 @@
void ComputeDecibel() {
int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000);
int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
- if (data_buf_all.empty()) {
- data_buf_all = waveform;
- data_buf = data_buf_all;
+ if (data_buf_all_size == 0) {
+ data_buf_all_size = waveform.size();
+ data_buf_size = data_buf_all_size;
} else {
- data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end());
+ data_buf_all_size += waveform.size();
}
- for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) {
+ for (int offset = 0; offset + frame_sample_length -1 < waveform.size(); offset += frame_shift_length) {
float sum = 0.0;
for (int i = 0; i < frame_sample_length; i++) {
sum += waveform[offset + i] * waveform[offset + i];
}
-// float decibel = 10 * log10(sum + 0.000001);
this->decibel.push_back(10 * log10(sum + 0.000001));
}
}
@@ -451,29 +455,16 @@
void ComputeScores(const std::vector<std::vector<float>> &scores) {
vad_opts.nn_eval_block_size = scores.size();
frm_cnt += scores.size();
- if (this->scores.empty()) {
- this->scores = scores; // the first calculation
- } else {
- this->scores.insert(this->scores.end(), scores.begin(), scores.end());
- }
+ this->scores = scores;
}
void PopDataBufTillFrame(int frame_idx) {
int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
- int start_pos=-1;
- int data_length= data_buf.size();
while (data_buf_start_frame < frame_idx) {
- if (data_length >= frame_sample_length) {
+ if (data_buf_size >= frame_sample_length) {
data_buf_start_frame += 1;
- start_pos= data_buf_start_frame* frame_sample_length;
- data_length=data_buf_all.size()-start_pos;
- } else {
- break;
+ data_buf_size = data_buf_all_size - data_buf_start_frame * frame_sample_length;
}
- }
- if (start_pos!=-1){
- data_buf.resize(data_length);
- std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin());
}
}
@@ -487,9 +478,9 @@
expected_sample_number += int(extra_sample);
}
if (end_point_is_sent_end) {
- expected_sample_number = std::max(expected_sample_number, int(data_buf.size()));
+ expected_sample_number = std::max(expected_sample_number, data_buf_size);
}
- if (data_buf.size() < expected_sample_number) {
+ if (data_buf_size < expected_sample_number) {
std::cout << "error in calling pop data_buf\n";
}
if (output_data_buf.size() == 0 || first_frm_is_start_point) {
@@ -503,27 +494,20 @@
if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
std::cout << "warning\n";
}
- int out_pos = (int) cur_seg.buffer.size();
+
int data_to_pop;
if (end_point_is_sent_end) {
data_to_pop = expected_sample_number;
} else {
data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
}
- if (data_to_pop > int(data_buf.size())) {
+ if (data_to_pop > data_buf_size) {
std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n";
- data_to_pop = (int) data_buf.size();
- expected_sample_number = (int) data_buf.size();
+ data_to_pop = data_buf_size;
+ expected_sample_number = data_buf_size;
}
cur_seg.doa = 0;
- for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) {
- cur_seg.buffer.push_back(data_buf.back());
- out_pos++;
- }
- for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) {
- cur_seg.buffer.push_back(data_buf.back());
- out_pos++;
- }
+
if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
std::cout << "Something wrong with the VAD algorithm\n";
}
@@ -619,7 +603,7 @@
if (sil_pdf_ids.size() > 0) {
std::vector<float> sil_pdf_scores;
for (auto sil_pdf_id: sil_pdf_ids) {
- sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
+ sil_pdf_scores.push_back(scores[t - idx_pre_chunk][sil_pdf_id]);
}
sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;
@@ -663,6 +647,7 @@
frame_state = GetFrameState(frm_cnt - 1 - i);
DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
}
+ idx_pre_chunk += scores.size();
return 0;
}
@@ -797,5 +782,4 @@
};
-
-
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index fbb682b..0a646f0 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -6,8 +6,9 @@
#include <fstream>
#include "precomp.h"
-void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) {
- session_options_.SetIntraOpNumThreads(1);
+namespace funasr {
+void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
+ session_options_.SetIntraOpNumThreads(thread_num);
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options_.DisableCpuMemArena();
@@ -296,5 +297,10 @@
void FsmnVad::Test() {
}
+FsmnVad::~FsmnVad() {
+}
+
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
}
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h
index 1d5f68c..7a6707c 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -8,7 +8,8 @@
#include "precomp.h"
-class FsmnVad {
+namespace funasr {
+class FsmnVad : public VadModel {
/**
* Author: Speech Lab of DAMO Academy, Alibaba Group
* Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -17,9 +18,9 @@
public:
FsmnVad();
+ ~FsmnVad();
void Test();
- void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config);
-
+ void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
void Reset();
@@ -63,5 +64,5 @@
int lfr_n = VAD_LFR_N;
};
-
+} // namespace funasr
#endif //VAD_SERVER_FSMNVAD_H
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
new file mode 100644
index 0000000..a8ee9a9
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
@@ -0,0 +1,98 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+using namespace std;
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+ }
+}
+
+int main(int argc, char *argv[])
+{
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", false, "", "string");
+
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(txt_path);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(txt_path, TXT_PATH, model_path);
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ int thread_num = 1;
+ FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num);
+
+ if (!punc_hanlde)
+ {
+ LOG(ERROR) << "FunASR init failed";
+ exit(-1);
+ }
+
+ gettimeofday(&end, NULL);
+ long seconds = (end.tv_sec - start.tv_sec);
+ long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+ // read txt_path
+ vector<string> txt_list;
+
+ if(model_path.find(TXT_PATH)!=model_path.end()){
+ ifstream in(model_path.at(TXT_PATH));
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ txt_list.emplace_back(line);
+ }
+ in.close();
+ }
+
+ long taking_micros = 0;
+ for(auto& txt_str : txt_list){
+ gettimeofday(&start, NULL);
+ string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO)<<"Results: "<<result;
+ }
+
+ LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+ CTTransformerUninit(punc_hanlde);
+ return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
index 45b6196..76624e7 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -10,7 +10,7 @@
#endif
#include <glog/logging.h>
-#include "libfunasrapi.h"
+#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
@@ -91,41 +91,21 @@
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
- TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
-
- TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", false, "", "string");
- TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string");
- TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", false, "", "string");
-
- TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
- TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", true, "", "string");
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
- cmd.add(vad_model);
- cmd.add(vad_cmvn);
- cmd.add(vad_config);
- cmd.add(am_model);
- cmd.add(am_cmvn);
- cmd.add(am_config);
- cmd.add(punc_model);
- cmd.add(punc_config);
+ cmd.add(model_dir);
+ cmd.add(quantize);
cmd.add(wav_scp);
cmd.add(thread_num);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
- GetValue(vad_model, VAD_MODEL_PATH, model_path);
- GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
- GetValue(vad_config, VAD_CONFIG_PATH, model_path);
- GetValue(am_model, AM_MODEL_PATH, model_path);
- GetValue(am_cmvn, AM_CMVN_PATH, model_path);
- GetValue(am_config, AM_CONFIG_PATH, model_path);
- GetValue(punc_model, PUNC_MODEL_PATH, model_path);
- GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
GetValue(wav_scp, WAV_SCP, model_path);
struct timeval start, end;
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
new file mode 100644
index 0000000..37513ae
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
@@ -0,0 +1,143 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <vector>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+using namespace std;
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+ }
+}
+
+void print_segs(vector<vector<int>>* vec) {
+ string seg_out="[";
+ for (int i = 0; i < vec->size(); i++) {
+ vector<int> inner_vec = (*vec)[i];
+ seg_out += "[";
+ for (int j = 0; j < inner_vec.size(); j++) {
+ seg_out += to_string(inner_vec[j]);
+ if (j != inner_vec.size() - 1) {
+ seg_out += ",";
+ }
+ }
+ seg_out += "]";
+ if (i != vec->size() - 1) {
+ seg_out += ",";
+ }
+ }
+ seg_out += "]";
+ LOG(INFO)<<seg_out;
+}
+
+int main(int argc, char *argv[])
+{
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+
+ TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
+ TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
+
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(wav_path);
+ cmd.add(wav_scp);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(wav_path, WAV_PATH, model_path);
+ GetValue(wav_scp, WAV_SCP, model_path);
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ int thread_num = 1;
+ FUNASR_HANDLE vad_hanlde=FsmnVadInit(model_path, thread_num);
+
+ if (!vad_hanlde)
+ {
+ LOG(ERROR) << "FunVad init failed";
+ exit(-1);
+ }
+
+ gettimeofday(&end, NULL);
+ long seconds = (end.tv_sec - start.tv_sec);
+ long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+ // read wav_path and wav_scp
+ vector<string> wav_list;
+
+ if(model_path.find(WAV_PATH)!=model_path.end()){
+ wav_list.emplace_back(model_path.at(WAV_PATH));
+ }
+ if(model_path.find(WAV_SCP)!=model_path.end()){
+ ifstream in(model_path.at(WAV_SCP));
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ }
+ in.close();
+ }
+
+ float snippet_time = 0.0f;
+ long taking_micros = 0;
+ for(auto& wav_file : wav_list){
+ gettimeofday(&start, NULL);
+ FUNASR_RESULT result=FsmnVadWavFile(vad_hanlde, wav_file.c_str(), RASR_NONE, NULL);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+ if (result)
+ {
+ vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
+ print_segs(vad_segments);
+ snippet_time += FsmnVadGetRetSnippetTime(result);
+ FsmnVadFreeResult(result);
+ }
+ else
+ {
+ LOG(ERROR) << ("No return data!\n");
+ }
+ }
+
+ LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
+ LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+ LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+ FsmnVadUninit(vad_hanlde);
+ return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
index 2d61bbb..343039d 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -14,7 +14,7 @@
#include <sstream>
#include <map>
#include <glog/logging.h>
-#include "libfunasrapi.h"
+#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
@@ -28,55 +28,46 @@
}
}
-int main(int argc, char *argv[])
+int main(int argc, char** argv)
{
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
- TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
- TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
-
- TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
- TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
- TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
-
- TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
- TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
- cmd.add(vad_model);
- cmd.add(vad_cmvn);
- cmd.add(vad_config);
- cmd.add(am_model);
- cmd.add(am_cmvn);
- cmd.add(am_config);
- cmd.add(punc_model);
- cmd.add(punc_config);
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_quant);
cmd.add(wav_path);
cmd.add(wav_scp);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
- GetValue(vad_model, VAD_MODEL_PATH, model_path);
- GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
- GetValue(vad_config, VAD_CONFIG_PATH, model_path);
- GetValue(am_model, AM_MODEL_PATH, model_path);
- GetValue(am_cmvn, AM_CMVN_PATH, model_path);
- GetValue(am_config, AM_CONFIG_PATH, model_path);
- GetValue(punc_model, PUNC_MODEL_PATH, model_path);
- GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(vad_dir, VAD_DIR, model_path);
+ GetValue(vad_quant, VAD_QUANT, model_path);
+ GetValue(punc_dir, PUNC_DIR, model_path);
+ GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(wav_path, WAV_PATH, model_path);
GetValue(wav_scp, WAV_SCP, model_path);
-
struct timeval start, end;
gettimeofday(&start, NULL);
int thread_num = 1;
- FUNASR_HANDLE asr_hanlde=FunASRInit(model_path, thread_num);
+ FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
if (!asr_hanlde)
{
@@ -116,7 +107,7 @@
long taking_micros = 0;
for(auto& wav_file : wav_list){
gettimeofday(&start, NULL);
- FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT result=FunOfflineRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -124,8 +115,7 @@
if (result)
{
string msg = FunASRGetResult(result, 0);
- setbuf(stdout, NULL);
- printf("Result: %s \n", msg.c_str());
+ LOG(INFO)<<"Result: "<<msg;
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
@@ -138,7 +128,7 @@
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
- FunASRUninit(asr_hanlde);
+ FunOfflineUninit(asr_hanlde);
return 0;
}
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
new file mode 100644
index 0000000..893ba70
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -0,0 +1,362 @@
+#include "precomp.h"
+#ifdef __cplusplus
+
+extern "C" {
+#endif
+
+ // APIs for Init
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {
+ funasr::Model* mm = funasr::CreateModel(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {
+ funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {
+ funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {
+ funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
+ return mm;
+ }
+
+ // APIs for ASR Infer
+ _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::Model* recog_obj = (funasr::Model*)handle;
+ if (!recog_obj)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ funasr::Audio audio(1);
+ if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+
+ float* buff;
+ int len;
+ int flag=0;
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = recog_obj->Forward(buff, len, flag);
+ p_result->msg += msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+
+ return p_result;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::Model* recog_obj = (funasr::Model*)handle;
+ if (!recog_obj)
+ return nullptr;
+
+ funasr::Audio audio(1);
+ if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+
+ float* buff;
+ int len;
+ int flag = 0;
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = recog_obj->Forward(buff, len, flag);
+ p_result->msg += msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+
+ return p_result;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::Model* recog_obj = (funasr::Model*)handle;
+ if (!recog_obj)
+ return nullptr;
+
+ funasr::Audio audio(1);
+ if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
+ return nullptr;
+
+ float* buff;
+ int len;
+ int flag = 0;
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = recog_obj->Forward(buff, len, flag);
+ p_result->msg += msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+
+ return p_result;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::Model* recog_obj = (funasr::Model*)handle;
+ if (!recog_obj)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ funasr::Audio audio(1);
+ if(!audio.LoadWav(sz_wavfile, &sampling_rate))
+ return nullptr;
+
+ float* buff;
+ int len;
+ int flag = 0;
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = recog_obj->Forward(buff, len, flag);
+ p_result->msg+= msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+
+ return p_result;
+ }
+
+ // APIs for VAD Infer
+ _FUNASRAPI FUNASR_RESULT FsmnVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
+ if (!vad_obj)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ funasr::Audio audio(1);
+ if(!audio.LoadWav(sz_wavfile, &sampling_rate))
+ return nullptr;
+
+ funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+
+ vector<std::vector<int>> vad_segments;
+ audio.Split(vad_obj, vad_segments);
+ p_result->segments = new vector<std::vector<int>>(vad_segments);
+
+ return p_result;
+ }
+
+ // APIs for PUNC Infer
+ _FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle;
+ if (!punc_obj)
+ return nullptr;
+
+ string punc_res = punc_obj->AddPunc(sz_sentence);
+ return punc_res;
+ }
+
+ // APIs for Offline-stream Infer
+ _FUNASRAPI FUNASR_RESULT FunOfflineRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
+ if (!offline_stream)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ funasr::Audio audio(1);
+ if(!audio.LoadWav(sz_wavfile, &sampling_rate))
+ return nullptr;
+ if(offline_stream->UseVad()){
+ audio.Split(offline_stream);
+ }
+
+ float* buff;
+ int len;
+ int flag = 0;
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+ p_result->msg+= msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+ if(offline_stream->UsePunc()){
+ string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
+ p_result->msg = punc_res;
+ }
+
+ return p_result;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunOfflineRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ {
+ funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
+ if (!offline_stream)
+ return nullptr;
+
+ funasr::Audio audio(1);
+ if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+ if(offline_stream->UseVad()){
+ audio.Split(offline_stream);
+ }
+
+ float* buff;
+ int len;
+ int flag = 0;
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ int n_step = 0;
+ int n_total = audio.GetQueueSize();
+ while (audio.Fetch(buff, len, flag) > 0) {
+ string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+ p_result->msg += msg;
+ n_step++;
+ if (fn_callback)
+ fn_callback(n_step, n_total);
+ }
+ if(offline_stream->UsePunc()){
+ string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
+ p_result->msg = punc_res;
+ }
+
+ return p_result;
+ }
+
+ _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
+ {
+ if (!result)
+ return 0;
+
+ return 1;
+ }
+
+ // APIs for GetRetSnippetTime
+ _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
+ {
+ if (!result)
+ return 0.0f;
+
+ return ((funasr::FUNASR_RECOG_RESULT*)result)->snippet_time;
+ }
+
+ _FUNASRAPI const float FsmnVadGetRetSnippetTime(FUNASR_RESULT result)
+ {
+ if (!result)
+ return 0.0f;
+
+ return ((funasr::FUNASR_VAD_RESULT*)result)->snippet_time;
+ }
+
+ // APIs for GetResult
+ _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
+ {
+ funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+ if(!p_result)
+ return nullptr;
+
+ return p_result->msg.c_str();
+ }
+
+ _FUNASRAPI vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index)
+ {
+ funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
+ if(!p_result)
+ return nullptr;
+
+ return p_result->segments;
+ }
+
+ // APIs for FreeResult
+ _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
+ {
+ if (result)
+ {
+ delete (funasr::FUNASR_RECOG_RESULT*)result;
+ }
+ }
+
+ _FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result)
+ {
+ funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
+ if (p_result)
+ {
+ if(p_result->segments){
+ delete p_result->segments;
+ }
+ delete p_result;
+ }
+ }
+
+ // APIs for Uninit
+ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
+ {
+ funasr::Model* recog_obj = (funasr::Model*)handle;
+
+ if (!recog_obj)
+ return;
+
+ delete recog_obj;
+ }
+
+ _FUNASRAPI void FsmnVadUninit(FUNASR_HANDLE handle)
+ {
+ funasr::VadModel* recog_obj = (funasr::VadModel*)handle;
+
+ if (!recog_obj)
+ return;
+
+ delete recog_obj;
+ }
+
+ _FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle)
+ {
+ funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle;
+
+ if (!punc_obj)
+ return;
+
+ delete punc_obj;
+ }
+
+ _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle)
+ {
+ funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
+
+ if (!offline_stream)
+ return;
+
+ delete offline_stream;
+ }
+
+#ifdef __cplusplus
+
+}
+#endif
+
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
deleted file mode 100644
index 01aa38a..0000000
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ /dev/null
@@ -1,210 +0,0 @@
-#include "precomp.h"
-#ifdef __cplusplus
-
-extern "C" {
-#endif
-
- // APIs for funasr
- _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
- {
- Model* mm = CreateModel(model_path, thread_num);
- return mm;
- }
-
- _FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
- {
- Model* mm = CreateModel(model_path, thread_num);
- return mm;
- }
-
- _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
- {
- Model* recog_obj = (Model*)handle;
- if (!recog_obj)
- return nullptr;
-
- int32_t sampling_rate = -1;
- Audio audio(1);
- if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
- return nullptr;
- if(recog_obj->UseVad()){
- audio.Split(recog_obj);
- }
-
- float* buff;
- int len;
- int flag=0;
- FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
- p_result->msg += msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
- }
- if(recog_obj->UsePunc()){
- string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
- p_result->msg = punc_res;
- }
-
- return p_result;
- }
-
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
- {
- Model* recog_obj = (Model*)handle;
- if (!recog_obj)
- return nullptr;
-
- Audio audio(1);
- if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
- return nullptr;
- if(recog_obj->UseVad()){
- audio.Split(recog_obj);
- }
-
- float* buff;
- int len;
- int flag = 0;
- FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
- p_result->msg += msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
- }
- if(recog_obj->UsePunc()){
- string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
- p_result->msg = punc_res;
- }
-
- return p_result;
- }
-
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
- {
- Model* recog_obj = (Model*)handle;
- if (!recog_obj)
- return nullptr;
-
- Audio audio(1);
- if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
- return nullptr;
- if(recog_obj->UseVad()){
- audio.Split(recog_obj);
- }
-
- float* buff;
- int len;
- int flag = 0;
- FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
- p_result->msg += msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
- }
- if(recog_obj->UsePunc()){
- string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
- p_result->msg = punc_res;
- }
-
- return p_result;
- }
-
- _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
- {
- Model* recog_obj = (Model*)handle;
- if (!recog_obj)
- return nullptr;
-
- int32_t sampling_rate = -1;
- Audio audio(1);
- if(!audio.LoadWav(sz_wavfile, &sampling_rate))
- return nullptr;
- if(recog_obj->UseVad()){
- audio.Split(recog_obj);
- }
-
- float* buff;
- int len;
- int flag = 0;
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
- while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
- p_result->msg+= msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
- }
- if(recog_obj->UsePunc()){
- string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
- p_result->msg = punc_res;
- }
-
- return p_result;
- }
-
- _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
- {
- if (!result)
- return 0;
-
- return 1;
- }
-
-
- _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
- {
- if (!result)
- return 0.0f;
-
- return ((FUNASR_RECOG_RESULT*)result)->snippet_time;
- }
-
- _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
- {
- FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result;
- if(!p_result)
- return nullptr;
-
- return p_result->msg.c_str();
- }
-
- _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
- {
- if (result)
- {
- delete (FUNASR_RECOG_RESULT*)result;
- }
- }
-
- _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
- {
- Model* recog_obj = (Model*)handle;
-
- if (!recog_obj)
- return;
-
- delete recog_obj;
- }
-
-#ifdef __cplusplus
-
-}
-#endif
-
diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp
index 52ce7ba..6badde6 100644
--- a/funasr/runtime/onnxruntime/src/model.cpp
+++ b/funasr/runtime/onnxruntime/src/model.cpp
@@ -1,8 +1,23 @@
#include "precomp.h"
+namespace funasr {
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
{
+ string am_model_path;
+ string am_cmvn_path;
+ string am_config_path;
+
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
+ am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+
Model *mm;
- mm = new paraformer::Paraformer(model_path, thread_num);
+ mm = new Paraformer();
+ mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
return mm;
}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/offline-stream.cpp b/funasr/runtime/onnxruntime/src/offline-stream.cpp
new file mode 100644
index 0000000..8170129
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/offline-stream.cpp
@@ -0,0 +1,64 @@
+#include "precomp.h"
+
+namespace funasr {
+OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+{
+ // VAD model
+ if(model_path.find(VAD_DIR) != model_path.end()){
+ use_vad = true;
+ string vad_model_path;
+ string vad_cmvn_path;
+ string vad_config_path;
+
+ vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
+ if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){
+ vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
+ }
+ vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
+ vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
+ vad_handle = make_unique<FsmnVad>();
+ vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ }
+
+ // AM model
+ if(model_path.find(MODEL_DIR) != model_path.end()){
+ string am_model_path;
+ string am_cmvn_path;
+ string am_config_path;
+
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
+ am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+
+ asr_handle = make_unique<Paraformer>();
+ asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+ }
+
+ // PUNC model
+ if(model_path.find(PUNC_DIR) != model_path.end()){
+ use_punc = true;
+ string punc_model_path;
+ string punc_config_path;
+
+ punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
+ if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
+ punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
+ }
+ punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+
+ punc_handle = make_unique<CTTransformer>();
+ punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ }
+}
+
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+{
+ OfflineStream *mm;
+ mm = new OfflineStream(model_path, thread_num);
+ return mm;
+}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/online-feature.cpp b/funasr/runtime/onnxruntime/src/online-feature.cpp
index 3f57e0b..a21589c 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.cpp
+++ b/funasr/runtime/onnxruntime/src/online-feature.cpp
@@ -1,11 +1,13 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
+ * Contributed by zhuzizyf(China Telecom).
*/
#include "online-feature.h"
#include <utility>
+namespace funasr {
OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
std::vector<std::vector<float>> cmvns)
: sample_rate_(sample_rate),
@@ -131,3 +133,5 @@
}
}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
index decaaf4..16e6e4b 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ b/funasr/runtime/onnxruntime/src/online-feature.h
@@ -1,13 +1,14 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
+ * Contributed by zhuzizyf(China Telecom).
*/
-
+#pragma once
#include <vector>
#include "precomp.h"
using namespace std;
-
+namespace funasr {
class OnlineFeature {
public:
@@ -53,3 +54,5 @@
bool input_finished_ = false;
};
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 136d228..74366a0 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -6,67 +6,14 @@
#include "precomp.h"
using namespace std;
-using namespace paraformer;
-Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
+namespace funasr {
+
+Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
-
- // VAD model
- if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
- use_vad = true;
- string vad_model_path;
- string vad_cmvn_path;
- string vad_config_path;
-
- try{
- vad_model_path = model_path.at(VAD_MODEL_PATH);
- vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
- vad_config_path = model_path.at(VAD_CONFIG_PATH);
- }catch(const out_of_range& e){
- LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what();
- exit(0);
- }
- vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path);
- }
-
- // AM model
- if(model_path.find(AM_MODEL_PATH) != model_path.end()){
- string am_model_path;
- string am_cmvn_path;
- string am_config_path;
-
- try{
- am_model_path = model_path.at(AM_MODEL_PATH);
- am_cmvn_path = model_path.at(AM_CMVN_PATH);
- am_config_path = model_path.at(AM_CONFIG_PATH);
- }catch(const out_of_range& e){
- LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
- exit(0);
- }
- InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
- }
-
- // PUNC model
- if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
- use_punc = true;
- string punc_model_path;
- string punc_config_path;
-
- try{
- punc_model_path = model_path.at(PUNC_MODEL_PATH);
- punc_config_path = model_path.at(PUNC_CONFIG_PATH);
- }catch(const out_of_range& e){
- LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
- exit(0);
- }
-
- punc_handle = make_unique<CTTransformer>();
- punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
- }
}
-void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -118,14 +65,6 @@
void Paraformer::Reset()
{
-}
-
-vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
- return vad_handle->Infer(pcm_data);
-}
-
-string Paraformer::AddPunc(const char* sz_input){
- return punc_handle->AddPunc(sz_input);
}
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -282,7 +221,7 @@
}
catch (std::exception const &e)
{
- printf(e.what());
+ LOG(ERROR)<<e.what();
}
return result;
@@ -291,12 +230,13 @@
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
- printf("Not Imp!!!!!!\n");
- return "Hello";
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
}
string Paraformer::Rescoring()
{
- printf("Not Imp!!!!!!\n");
- return "Hello";
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
}
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index f3eb059..533c16f 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -2,16 +2,11 @@
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
-
#pragma once
-
-
-#ifndef PARAFORMER_MODELIMP_H
-#define PARAFORMER_MODELIMP_H
#include "precomp.h"
-namespace paraformer {
+namespace funasr {
class Paraformer : public Model {
/**
@@ -23,9 +18,6 @@
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
- std::unique_ptr<FsmnVad> vad_handle;
- std::unique_ptr<CTTransformer> punc_handle;
-
Vocab* vocab;
vector<float> means_list;
vector<float> vars_list;
@@ -36,7 +28,6 @@
void LoadCmvn(const char *filename);
vector<float> ApplyLfr(const vector<float> &in);
void ApplyCmvn(vector<float> *v);
-
string GreedySearch( float* in, int n_len, int64_t token_nums);
std::shared_ptr<Ort::Session> m_session;
@@ -46,23 +37,16 @@
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
- bool use_vad=false;
- bool use_punc=false;
public:
- Paraformer(std::map<std::string, std::string>& model_path, int thread_num=0);
+ Paraformer();
~Paraformer();
- void InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void Reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
string ForwardChunk(float* din, int len, int flag);
string Forward(float* din, int len, int flag);
string Rescoring();
- std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
- string AddPunc(const char* sz_input);
- bool UseVad(){return use_vad;};
- bool UsePunc(){return use_punc;};
};
-} // namespace paraformer
-#endif
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 1630e55..e607dbf 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -30,6 +30,10 @@
#include "com-define.h"
#include "commonfunc.h"
#include "predefine-coe.h"
+#include "model.h"
+#include "vad-model.h"
+#include "punc-model.h"
+#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "fsmn-vad.h"
@@ -39,9 +43,6 @@
#include "tensor.h"
#include "util.h"
#include "resample.h"
-#include "model.h"
-//#include "vad-model.h"
#include "paraformer.h"
-#include "libfunasrapi.h"
-
-using namespace paraformer;
+#include "offline-stream.h"
+#include "funasrruntime.h"
diff --git a/funasr/runtime/onnxruntime/src/predefine-coe.h b/funasr/runtime/onnxruntime/src/predefine-coe.h
index 93012d8..17c263f 100644
--- a/funasr/runtime/onnxruntime/src/predefine-coe.h
+++ b/funasr/runtime/onnxruntime/src/predefine-coe.h
@@ -3,6 +3,7 @@
#include <stdint.h>
+namespace funasr {
const int32_t melcoe_hex[] = {
0x3f01050c, 0x3e0afb11, 0x3f5d413c, 0x3f547fd0, 0x3e2e00c1, 0x3f132970,
@@ -590,3 +591,5 @@
0x39164323, 0x3910f3c6, 0x390bd472, 0x3906e374, 0x39021f2b, 0x38fb0c03,
0x38f22ce3, 0x38e99e04, 0x38e15c92, 0x38d965ce};
#endif
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/punc-model.cpp b/funasr/runtime/onnxruntime/src/punc-model.cpp
new file mode 100644
index 0000000..52ba0df
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/punc-model.cpp
@@ -0,0 +1,22 @@
+#include "precomp.h"
+
+namespace funasr {
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num)
+{
+ PuncModel *mm;
+ mm = new CTTransformer();
+
+ string punc_model_path;
+ string punc_config_path;
+
+ punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
+
+ mm->InitPunc(punc_model_path, punc_config_path, thread_num);
+ return mm;
+}
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/resample.cc b/funasr/runtime/onnxruntime/src/resample.cpp
similarity index 99%
rename from funasr/runtime/onnxruntime/src/resample.cc
rename to funasr/runtime/onnxruntime/src/resample.cpp
index 0238752..9c74dc8 100644
--- a/funasr/runtime/onnxruntime/src/resample.cc
+++ b/funasr/runtime/onnxruntime/src/resample.cpp
@@ -31,6 +31,7 @@
#include <cstdlib>
#include <type_traits>
+namespace funasr {
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
@@ -303,3 +304,4 @@
}
}
}
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/resample.h b/funasr/runtime/onnxruntime/src/resample.h
index b9a283a..5cfc971 100644
--- a/funasr/runtime/onnxruntime/src/resample.h
+++ b/funasr/runtime/onnxruntime/src/resample.h
@@ -21,11 +21,11 @@
*/
// this file is copied and modified from
// kaldi/src/feat/resample.h
-
+#pragma once
#include <cstdint>
#include <vector>
-
+namespace funasr {
/*
We require that the input and output sampling rate be specified as
integers, as this is an easy way to specify that their ratio be rational.
@@ -135,3 +135,4 @@
std::vector<float> input_remainder_; ///< A small trailing part of the
///< previously seen input signal.
};
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/tensor.h b/funasr/runtime/onnxruntime/src/tensor.h
index 3b7a633..a2a7bc3 100644
--- a/funasr/runtime/onnxruntime/src/tensor.h
+++ b/funasr/runtime/onnxruntime/src/tensor.h
@@ -5,6 +5,8 @@
using namespace std;
+namespace funasr {
+
template <typename T> class Tensor {
private:
void alloc_buff();
@@ -152,4 +154,6 @@
fwrite(buff, 1, buff_size * sizeof(T), fp);
fclose(fp);
}
+
+} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index 5f29b46..a8f6301 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -5,12 +5,17 @@
#include "precomp.h"
+namespace funasr {
CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
{
OpenYaml(sz_yamlfile);
}
CTokenizer::CTokenizer():m_ready(false)
+{
+}
+
+CTokenizer::~CTokenizer()
{
}
@@ -216,3 +221,5 @@
}
id_out= String2Ids(str_out);
}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h
index 4ff1809..419791b 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.h
+++ b/funasr/runtime/onnxruntime/src/tokenizer.h
@@ -6,6 +6,7 @@
#pragma once
#include <yaml-cpp/yaml.h>
+namespace funasr {
class CTokenizer {
private:
@@ -17,6 +18,7 @@
CTokenizer(const char* sz_yamlfile);
CTokenizer();
+ ~CTokenizer();
bool OpenYaml(const char* sz_yamlfile);
void ReadYaml(const YAML::Node& node);
vector<string> Id2String(vector<int> input);
@@ -30,3 +32,5 @@
void Tokenize(const char* str_info, vector<string>& str_out, vector<int>& id_out);
};
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/util.cpp b/funasr/runtime/onnxruntime/src/util.cpp
index c5c27af..d29c5c0 100644
--- a/funasr/runtime/onnxruntime/src/util.cpp
+++ b/funasr/runtime/onnxruntime/src/util.cpp
@@ -1,6 +1,7 @@
#include "precomp.h"
+namespace funasr {
float *LoadParams(const char *filename)
{
@@ -178,3 +179,5 @@
}
}
}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/util.h b/funasr/runtime/onnxruntime/src/util.h
index 6327f7b..95ef458 100644
--- a/funasr/runtime/onnxruntime/src/util.h
+++ b/funasr/runtime/onnxruntime/src/util.h
@@ -1,10 +1,9 @@
-
-
#ifndef UTIL_H
#define UTIL_H
using namespace std;
+namespace funasr {
extern float *LoadParams(const char *filename);
extern void SaveDataFile(const char *filename, void *data, uint32_t len);
@@ -27,4 +26,5 @@
string PathAppend(const string &p1, const string &p2);
+} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/vad-model.cpp b/funasr/runtime/onnxruntime/src/vad-model.cpp
new file mode 100644
index 0000000..764db00
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/vad-model.cpp
@@ -0,0 +1,24 @@
+#include "precomp.h"
+
+namespace funasr {
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
+{
+ VadModel *mm;
+ mm = new FsmnVad();
+
+ string vad_model_path;
+ string vad_cmvn_path;
+ string vad_config_path;
+
+ vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), VAD_CMVN_NAME);
+ vad_config_path = PathAppend(model_path.at(MODEL_DIR), VAD_CONFIG_NAME);
+
+ mm->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ return mm;
+}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index 53233b3..65af8b6 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -10,6 +10,7 @@
using namespace std;
+namespace funasr {
Vocab::Vocab(const char *filename)
{
ifstream in(filename);
@@ -151,3 +152,5 @@
{
return vocab.size();
}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/vocab.h b/funasr/runtime/onnxruntime/src/vocab.h
index a3fdf65..6c4e523 100644
--- a/funasr/runtime/onnxruntime/src/vocab.h
+++ b/funasr/runtime/onnxruntime/src/vocab.h
@@ -7,6 +7,7 @@
#include <vector>
using namespace std;
+namespace funasr {
class Vocab {
private:
vector<string> vocab;
@@ -22,4 +23,5 @@
string Vector2StringV2(vector<int> in);
};
+} // namespace funasr
#endif
diff --git a/funasr/runtime/python/grpc/Readme.md b/funasr/runtime/python/grpc/Readme.md
index 895013a..742268b 100644
--- a/funasr/runtime/python/grpc/Readme.md
+++ b/funasr/runtime/python/grpc/Readme.md
@@ -1,4 +1,4 @@
-# Using funasr with grpc-python
+# Service with grpc-python
We can send streaming audio data to server in real-time with grpc client every 10 ms e.g., and get transcribed text when stop speaking.
The audio data is in streaming, the asr inference process is in offline.
diff --git a/funasr/runtime/python/grpc/proto/paraformer.proto b/funasr/runtime/python/grpc/proto/paraformer.proto
index b221ee2..6c336a8 100644
--- a/funasr/runtime/python/grpc/proto/paraformer.proto
+++ b/funasr/runtime/python/grpc/proto/paraformer.proto
@@ -1,19 +1,5 @@
-// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
syntax = "proto3";
-option java_package = "ex.grpc";
option objc_class_prefix = "paraformer";
package paraformer;
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index aeb91e7..6fd01e4 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -64,7 +64,7 @@
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int32')
data = {
"text": mini_sentence_id[None,:],
"text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
@@ -148,7 +148,7 @@
else:
precache = ""
cache = []
- full_text = precache + text
+ full_text = precache + " " + text
split_text = code_mix_split_words(full_text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
@@ -166,7 +166,7 @@
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
data = {
"input": mini_sentence_id[None,:],
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
index b5b3312..3cda80d 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -229,10 +229,11 @@
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
+ self.idx_pre_chunk = 0
self.max_time_out = False
self.decibel = []
- self.data_buf = None
- self.data_buf_all = None
+ self.data_buf_size = 0
+ self.data_buf_all_size = 0
self.waveform = None
self.ResetDetection()
@@ -259,10 +260,11 @@
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
+ self.idx_pre_chunk = 0
self.max_time_out = False
self.decibel = []
- self.data_buf = None
- self.data_buf_all = None
+ self.data_buf_size = 0
+ self.data_buf_all_size = 0
self.waveform = None
self.ResetDetection()
@@ -280,11 +282,11 @@
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
- if self.data_buf_all is None:
- self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
- self.data_buf = self.data_buf_all
+ if self.data_buf_all_size == 0:
+ self.data_buf_all_size = len(self.waveform[0])
+ self.data_buf_size = self.data_buf_all_size
else:
- self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
+ self.data_buf_all_size += len(self.waveform[0])
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
self.decibel.append(
10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
@@ -294,17 +296,14 @@
# scores = self.encoder(feats, in_cache) # return B * T * D
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
- if self.scores is None:
- self.scores = scores # the first calculation
- else:
- self.scores = np.concatenate((self.scores, scores), axis=1)
+ self.scores=scores
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
- if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+ if self.data_buf_size >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
- self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
- self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+ self.data_buf_size = self.data_buf_all_size-self.data_buf_start_frame * int(
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
@@ -315,8 +314,8 @@
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
- expected_sample_number = max(expected_sample_number, len(self.data_buf))
- if len(self.data_buf) < expected_sample_number:
+ expected_sample_number = max(expected_sample_number, self.data_buf_size)
+ if self.data_buf_size < expected_sample_number:
print('error in calling pop data_buf\n')
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
@@ -334,10 +333,10 @@
data_to_pop = expected_sample_number
else:
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
- if data_to_pop > len(self.data_buf):
- print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
- data_to_pop = len(self.data_buf)
- expected_sample_number = len(self.data_buf)
+ if data_to_pop > self.data_buf_size:
+ print('VAD data_to_pop is bigger than self.data_buf_size!!!\n')
+ data_to_pop = self.data_buf_size
+ expected_sample_number = self.data_buf_size
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
@@ -420,7 +419,7 @@
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
assert len(self.scores) == 1 # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
- sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+ sil_pdf_scores = [self.scores[0][t - self.idx_pre_chunk][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
@@ -502,7 +501,7 @@
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
-
+ self.idx_pre_chunk += self.scores.shape[1]
return 0
def DetectLastFrames(self) -> int:
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 06603f0..0b249dd 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.6'
+VERSION_NUM = '0.0.8'
setuptools.setup(
name=MODULE_NAME,
diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md
index 7191ed0..7ca5730 100644
--- a/funasr/runtime/python/websocket/README.md
+++ b/funasr/runtime/python/websocket/README.md
@@ -1,6 +1,6 @@
# Service with websocket-python
-This is a demo using funasr pipeline with websocket python-api.
+This is a demo using funasr pipeline with websocket python-api. It supports the offline, online, offline/online-2pass unifying speech recognition.
## For the Server
@@ -22,25 +22,49 @@
### Start server
#### ASR offline server
-
-[//]: # (```shell)
-
-[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-
-[//]: # (```)
-#### ASR streaming server
+##### API-reference
```shell
-python ws_server_online.py --host "0.0.0.0" --port 10095
+python ws_server_offline.py \
+--port [port id] \
+--asr_model [asr model_name] \
+--punc_model [punc model_name] \
+--ngpu [0 or 1] \
+--ncpu [1 or 4]
```
-####
+##### Usage examples
+```shell
+python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+```
+
+#### ASR streaming server
+##### API-reference
+```shell
+python ws_server_online.py \
+--port [port id] \
+--asr_model_online [asr model_name] \
+--ngpu [0 or 1] \
+--ncpu [1 or 4]
+```
+##### Usage examples
+```shell
+python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+```
#### ASR offline/online 2pass server
-
-[//]: # (```shell)
-
-[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-
-[//]: # (```)
+##### API-reference
+```shell
+python ws_server_2pass.py \
+--port [port id] \
+--asr_model [asr model_name] \
+--asr_model_online [asr model_name] \
+--punc_model [punc model_name] \
+--ngpu [0 or 1] \
+--ncpu [1 or 4]
+```
+##### Usage examples
+```shell
+python ws_server_2pass.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+```
## For the client
@@ -51,13 +75,56 @@
pip install -r requirements_client.txt
```
-Start client
-
+### Start client
+#### API-reference
```shell
-# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python ws_client.py --host "127.0.0.1" --port 10096 --chunk_size "5,10,5"
+python ws_client.py \
+--host [ip_address] \
+--port [port id] \
+--chunk_size ["5,10,5"=600ms, "8,8,4"=480ms] \
+--chunk_interval [duration of send chunk_size/chunk_interval] \
+--words_max_print [max number of words to print] \
+--audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
+--output_dir [if set, write the results to output_dir] \
+--send_without_sleep [only set for offline]
+```
+#### Usage examples
+##### ASR offline client
+Recording from mircrophone
+```shell
+# --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_interval 10 --words_max_print 100
+```
+Loadding from wav.scp(kaldi style)
+```shell
+# --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_interval 10 --words_max_print 100 --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
```
+##### ASR streaming client
+Recording from mircrophone
+```shell
+# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "5,10,5" --words_max_print 100
+```
+Loadding from wav.scp(kaldi style)
+```shell
+# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "5,10,5" --audio_in "./data/wav.scp" --words_max_print 100 --output_dir "./results"
+```
+
+##### ASR offline/online 2pass client
+Recording from mircrophone
+```shell
+# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4" --words_max_print 10000
+```
+Loadding from wav.scp(kaldi style)
+```shell
+# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
+python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4" --audio_in "./data/wav.scp" --words_max_print 10000 --output_dir "./results"
+```
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
-2. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service.
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
+3. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service of offline model.
diff --git a/funasr/runtime/python/websocket/parse_args.py b/funasr/runtime/python/websocket/parse_args.py
index 2528a76..d170be8 100644
--- a/funasr/runtime/python/websocket/parse_args.py
+++ b/funasr/runtime/python/websocket/parse_args.py
@@ -31,5 +31,10 @@
type=int,
default=1,
help="0 for cpu, 1 for gpu")
+parser.add_argument("--ncpu",
+ type=int,
+ default=1,
+ help="cpu cores")
-args = parser.parse_args()
\ No newline at end of file
+args = parser.parse_args()
+print(args)
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index 8bbf103..a4a6d9f 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -6,6 +6,13 @@
# import threading
import argparse
import json
+import traceback
+from multiprocessing import Process
+from funasr.fileio.datadir_writer import DatadirWriter
+
+import logging
+
+logging.basicConfig(level=logging.ERROR)
parser = argparse.ArgumentParser()
parser.add_argument("--host",
@@ -30,15 +37,35 @@
type=str,
default=None,
help="audio_in")
+parser.add_argument("--send_without_sleep",
+ action="store_true",
+ default=False,
+ help="if audio_in is set, send_without_sleep")
+parser.add_argument("--test_thread_num",
+ type=int,
+ default=1,
+ help="test_thread_num")
+parser.add_argument("--words_max_print",
+ type=int,
+ default=100,
+ help="chunk")
+parser.add_argument("--output_dir",
+ type=str,
+ default=None,
+ help="output_dir")
args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
-
+print(args)
# voices = asyncio.Queue()
from queue import Queue
voices = Queue()
-# 鍏朵粬鍑芥暟鍙互閫氳繃璋冪敤send(data)鏉ュ彂閫佹暟鎹紝渚嬪锛�
+ibest_writer = None
+if args.output_dir is not None:
+ writer = DatadirWriter(args.output_dir)
+ ibest_writer = writer[f"1best_recog"]
+
async def record_microphone():
is_finished = False
import pyaudio
@@ -65,11 +92,9 @@
message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
voices.put(message)
- #print(voices.qsize())
await asyncio.sleep(0.005)
-# 鍏朵粬鍑芥暟鍙互閫氳繃璋冪敤send(data)鏉ュ彂閫佹暟鎹紝渚嬪锛�
async def record_from_scp():
import wave
global voices
@@ -81,19 +106,17 @@
wavs = [args.audio_in]
for wav in wavs:
wav_splits = wav.strip().split()
+ wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
+
# bytes_f = open(wav_path, "rb")
# bytes_data = bytes_f.read()
with wave.open(wav_path, "rb") as wav_file:
- # 鑾峰彇闊抽鍙傛暟
params = wav_file.getparams()
- # 鑾峰彇澶翠俊鎭殑闀垮害
# header_length = wav_file.getheaders()[0][1]
- # 璇诲彇闊抽甯ф暟鎹紝璺宠繃澶翠俊鎭�
# wav_file.setpos(header_length)
frames = wav_file.readframes(wav_file.getnframes())
- # 灏嗛煶棰戝抚鏁版嵁杞崲涓哄瓧鑺傜被鍨嬬殑鏁版嵁
audio_bytes = bytes(frames)
# stride = int(args.chunk_size/1000*16000*2)
stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
@@ -106,12 +129,12 @@
beg = i*stride
data = audio_bytes[beg:beg+stride]
data = data.decode('ISO-8859-1')
- message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished})
+ message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished, "wav_name": wav_name})
voices.put(message)
# print("data_chunk: ", len(data_chunk))
# print(voices.qsize())
-
- await asyncio.sleep(60*args.chunk_size[1]/args.chunk_interval/1000)
+ sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
+ await asyncio.sleep(sleep_duration)
is_finished = True
message = json.dumps({"is_finished": is_finished})
@@ -126,31 +149,57 @@
data = voices.get()
voices.task_done()
try:
- await websocket.send(data) # 閫氳繃ws瀵硅薄鍙戦�佹暟鎹�
+ await websocket.send(data)
except Exception as e:
print('Exception occurred:', e)
+ traceback.print_exc()
+ exit(0)
await asyncio.sleep(0.005)
await asyncio.sleep(0.005)
-async def message():
+async def message(id):
global websocket
text_print = ""
+ text_print_2pass_online = ""
+ text_print_2pass_offline = ""
while True:
try:
meg = await websocket.recv()
meg = json.loads(meg)
- # print(meg, end = '')
- # print("\r")
- text = meg["text"][0]
- text_print += text
- text_print = text_print[-55:]
- os.system('clear')
- print("\r"+text_print)
+ wav_name = meg.get("wav_name", "demo")
+ # print(wav_name)
+ text = meg["text"]
+ if ibest_writer is not None:
+ ibest_writer["text"][wav_name] = text
+
+ if meg["mode"] == "online":
+ text_print += " {}".format(text)
+ text_print = text_print[-args.words_max_print:]
+ os.system('clear')
+ print("\rpid"+str(id)+": "+text_print)
+ elif meg["mode"] == "online":
+ text_print += "{}".format(text)
+ text_print = text_print[-args.words_max_print:]
+ os.system('clear')
+ print("\rpid"+str(id)+": "+text_print)
+ else:
+ if meg["mode"] == "2pass-online":
+ text_print_2pass_online += " {}".format(text)
+ text_print = text_print_2pass_offline + text_print_2pass_online
+ else:
+ text_print_2pass_online = ""
+ text_print = text_print_2pass_offline + "{}".format(text)
+ text_print_2pass_offline += "{}".format(text)
+ text_print = text_print[-args.words_max_print:]
+ os.system('clear')
+ print("\rpid" + str(id) + ": " + text_print)
+
except Exception as e:
print("Exception:", e)
-
+ traceback.print_exc()
+ exit(0)
async def print_messge():
global websocket
@@ -161,22 +210,36 @@
print(meg)
except Exception as e:
print("Exception:", e)
+ traceback.print_exc()
+ exit(0)
-
-async def ws_client():
- global websocket # 瀹氫箟涓�涓叏灞�鍙橀噺ws锛岀敤浜庝繚瀛榳ebsocket杩炴帴瀵硅薄
- # uri = "ws://11.167.134.197:8899"
+async def ws_client(id):
+ global websocket
uri = "ws://{}:{}".format(args.host, args.port)
- #ws = await websockets.connect(uri, subprotocols=["binary"]) # 鍒涘缓涓�涓暱杩炴帴
async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
if args.audio_in is not None:
- task = asyncio.create_task(record_from_scp()) # 鍒涘缓涓�涓悗鍙颁换鍔″綍闊�
+ task = asyncio.create_task(record_from_scp())
else:
- task = asyncio.create_task(record_microphone()) # 鍒涘缓涓�涓悗鍙颁换鍔″綍闊�
- task2 = asyncio.create_task(ws_send()) # 鍒涘缓涓�涓悗鍙颁换鍔″彂閫�
- task3 = asyncio.create_task(message()) # 鍒涘缓涓�涓悗鍙版帴鏀舵秷鎭殑浠诲姟
+ task = asyncio.create_task(record_microphone())
+ task2 = asyncio.create_task(ws_send())
+ task3 = asyncio.create_task(message(id))
await asyncio.gather(task, task2, task3)
+def one_thread(id):
+ asyncio.get_event_loop().run_until_complete(ws_client(id))
+ asyncio.get_event_loop().run_forever()
-asyncio.get_event_loop().run_until_complete(ws_client()) # 鍚姩鍗忕▼
-asyncio.get_event_loop().run_forever()
+
+if __name__ == '__main__':
+ process_list = []
+ for i in range(args.test_thread_num):
+ p = Process(target=one_thread,args=(i,))
+ p.start()
+ process_list.append(p)
+
+ for i in process_list:
+ p.join()
+
+ print('end')
+
+
diff --git a/funasr/runtime/python/websocket/ws_server_2pass.py b/funasr/runtime/python/websocket/ws_server_2pass.py
new file mode 100644
index 0000000..ced67ff
--- /dev/null
+++ b/funasr/runtime/python/websocket/ws_server_2pass.py
@@ -0,0 +1,182 @@
+import asyncio
+import json
+import websockets
+import time
+import logging
+import tracemalloc
+import numpy as np
+
+from parse_args import args
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
+
+tracemalloc.start()
+
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+
+websocket_users = set()
+
+print("model loading")
+# asr
+inference_pipeline_asr = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.asr_model,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+ model_revision=None)
+
+
+# vad
+inference_pipeline_vad = pipeline(
+ task=Tasks.voice_activity_detection,
+ model=args.vad_model,
+ model_revision=None,
+ output_dir=None,
+ batch_size=1,
+ mode='online',
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+)
+
+if args.punc_model != "":
+ inference_pipeline_punc = pipeline(
+ task=Tasks.punctuation,
+ model=args.punc_model,
+ model_revision=None,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+ )
+else:
+ inference_pipeline_punc = None
+
+inference_pipeline_asr_online = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.asr_model_online,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+ model_revision='v1.0.4')
+
+print("model loaded")
+
+async def ws_serve(websocket, path):
+ frames = []
+ frames_asr = []
+ frames_asr_online = []
+ global websocket_users
+ websocket_users.add(websocket)
+ websocket.param_dict_asr = {}
+ websocket.param_dict_asr_online = {"cache": dict()}
+ websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
+ websocket.param_dict_punc = {'cache': list()}
+ websocket.vad_pre_idx = 0
+ speech_start = False
+
+ try:
+ async for message in websocket:
+ message = json.loads(message)
+ is_finished = message["is_finished"]
+ if not is_finished:
+ audio = bytes(message['audio'], 'ISO-8859-1')
+ frames.append(audio)
+ duration_ms = len(audio)//32
+ websocket.vad_pre_idx += duration_ms
+
+ is_speaking = message["is_speaking"]
+ websocket.param_dict_vad["is_final"] = not is_speaking
+ websocket.param_dict_asr_online["is_final"] = not is_speaking
+ websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
+ websocket.wav_name = message.get("wav_name", "demo")
+ # asr online
+ frames_asr_online.append(audio)
+ if len(frames_asr_online) % message["chunk_interval"] == 0:
+ audio_in = b"".join(frames_asr_online)
+ await async_asr_online(websocket, audio_in)
+ frames_asr_online = []
+ if speech_start:
+ frames_asr.append(audio)
+ # vad online
+ speech_start_i, speech_end_i = await async_vad(websocket, audio)
+ if speech_start_i:
+ speech_start = True
+ beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
+ frames_pre = frames[-beg_bias:]
+ frames_asr = []
+ frames_asr.extend(frames_pre)
+ # asr punc offline
+ if speech_end_i or not is_speaking:
+ audio_in = b"".join(frames_asr)
+ await async_asr(websocket, audio_in)
+ frames_asr = []
+ speech_start = False
+ frames_asr_online = []
+ websocket.param_dict_asr_online = {"cache": dict()}
+ if not is_speaking:
+ websocket.vad_pre_idx = 0
+ frames = []
+ websocket.param_dict_vad = {'in_cache': dict()}
+ else:
+ frames = frames[-20:]
+
+
+ except websockets.ConnectionClosed:
+ print("ConnectionClosed...", websocket_users)
+ websocket_users.remove(websocket)
+ except websockets.InvalidState:
+ print("InvalidState...")
+ except Exception as e:
+ print("Exception:", e)
+
+
+async def async_vad(websocket, audio_in):
+
+ segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+
+ speech_start = False
+ speech_end = False
+
+ if len(segments_result) == 0 or len(segments_result["text"]) > 1:
+ return speech_start, speech_end
+ if segments_result["text"][0][0] != -1:
+ speech_start = segments_result["text"][0][0]
+ if segments_result["text"][0][1] != -1:
+ speech_end = True
+ return speech_start, speech_end
+
+
+async def async_asr(websocket, audio_in):
+ if len(audio_in) > 0:
+ # print(len(audio_in))
+ audio_in = load_bytes(audio_in)
+
+ rec_result = inference_pipeline_asr(audio_in=audio_in,
+ param_dict=websocket.param_dict_asr)
+ # print(rec_result)
+ if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
+ rec_result = inference_pipeline_punc(text_in=rec_result['text'],
+ param_dict=websocket.param_dict_punc)
+ # print("offline", rec_result)
+ message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
+ await websocket.send(message)
+
+
+async def async_asr_online(websocket, audio_in):
+ if len(audio_in) > 0:
+ audio_in = load_bytes(audio_in)
+ rec_result = inference_pipeline_asr_online(audio_in=audio_in,
+ param_dict=websocket.param_dict_asr_online)
+ if websocket.param_dict_asr_online["is_final"]:
+ websocket.param_dict_asr_online["cache"] = dict()
+ if "text" in rec_result:
+ if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
+ # print("online", rec_result)
+ message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
+ await websocket.send(message)
+
+
+start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+asyncio.get_event_loop().run_until_complete(start_server)
+asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_offline.py b/funasr/runtime/python/websocket/ws_server_offline.py
new file mode 100644
index 0000000..15578f6
--- /dev/null
+++ b/funasr/runtime/python/websocket/ws_server_offline.py
@@ -0,0 +1,150 @@
+import asyncio
+import json
+import websockets
+import time
+import logging
+import tracemalloc
+import numpy as np
+
+from parse_args import args
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
+
+tracemalloc.start()
+
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+
+websocket_users = set()
+
+print("model loading")
+# asr
+inference_pipeline_asr = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.asr_model,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+ model_revision=None)
+
+
+# vad
+inference_pipeline_vad = pipeline(
+ task=Tasks.voice_activity_detection,
+ model=args.vad_model,
+ model_revision=None,
+ output_dir=None,
+ batch_size=1,
+ mode='online',
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+)
+
+if args.punc_model != "":
+ inference_pipeline_punc = pipeline(
+ task=Tasks.punctuation,
+ model=args.punc_model,
+ model_revision=None,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
+ )
+else:
+ inference_pipeline_punc = None
+
+print("model loaded")
+
+async def ws_serve(websocket, path):
+ frames = []
+ frames_asr = []
+ global websocket_users
+ websocket_users.add(websocket)
+ websocket.param_dict_asr = {}
+ websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
+ websocket.param_dict_punc = {'cache': list()}
+ websocket.vad_pre_idx = 0
+ speech_start = False
+
+ try:
+ async for message in websocket:
+ message = json.loads(message)
+ is_finished = message["is_finished"]
+ if not is_finished:
+ audio = bytes(message['audio'], 'ISO-8859-1')
+ frames.append(audio)
+ duration_ms = len(audio)//32
+ websocket.vad_pre_idx += duration_ms
+
+ is_speaking = message["is_speaking"]
+ websocket.param_dict_vad["is_final"] = not is_speaking
+ websocket.wav_name = message.get("wav_name", "demo")
+ if speech_start:
+ frames_asr.append(audio)
+ speech_start_i, speech_end_i = await async_vad(websocket, audio)
+ if speech_start_i:
+ speech_start = True
+ beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
+ frames_pre = frames[-beg_bias:]
+ frames_asr = []
+ frames_asr.extend(frames_pre)
+ if speech_end_i or not is_speaking:
+ audio_in = b"".join(frames_asr)
+ await async_asr(websocket, audio_in)
+ frames_asr = []
+ speech_start = False
+ if not is_speaking:
+ websocket.vad_pre_idx = 0
+ frames = []
+ websocket.param_dict_vad = {'in_cache': dict()}
+ else:
+ frames = frames[-20:]
+
+
+ except websockets.ConnectionClosed:
+ print("ConnectionClosed...", websocket_users)
+ websocket_users.remove(websocket)
+ except websockets.InvalidState:
+ print("InvalidState...")
+ except Exception as e:
+ print("Exception:", e)
+
+
+async def async_vad(websocket, audio_in):
+
+ segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+
+ speech_start = False
+ speech_end = False
+
+ if len(segments_result) == 0 or len(segments_result["text"]) > 1:
+ return speech_start, speech_end
+ if segments_result["text"][0][0] != -1:
+ speech_start = segments_result["text"][0][0]
+ if segments_result["text"][0][1] != -1:
+ speech_end = True
+ return speech_start, speech_end
+
+
+async def async_asr(websocket, audio_in):
+ if len(audio_in) > 0:
+ # print(len(audio_in))
+ audio_in = load_bytes(audio_in)
+
+ rec_result = inference_pipeline_asr(audio_in=audio_in,
+ param_dict=websocket.param_dict_asr)
+ # print(rec_result)
+ if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
+ rec_result = inference_pipeline_punc(text_in=rec_result['text'],
+ param_dict=websocket.param_dict_punc)
+ # print(rec_result)
+ message = json.dumps({"mode": "offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
+ await websocket.send(message)
+
+
+
+
+
+start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+asyncio.get_event_loop().run_until_complete(start_server)
+asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py
index 7ef0e21..3c0fb16 100644
--- a/funasr/runtime/python/websocket/ws_server_online.py
+++ b/funasr/runtime/python/websocket/ws_server_online.py
@@ -12,7 +12,7 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
-from funasr_onnx.utils.frontend import load_bytes
+from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
tracemalloc.start()
@@ -28,6 +28,8 @@
inference_pipeline_asr_online = pipeline(
task=Tasks.auto_speech_recognition,
model=args.asr_model_online,
+ ngpu=args.ngpu,
+ ncpu=args.ncpu,
model_revision='v1.0.4')
print("model loaded")
@@ -35,14 +37,10 @@
async def ws_serve(websocket, path):
- frames_online = []
+ frames_asr_online = []
global websocket_users
- websocket.send_msg = Queue()
websocket_users.add(websocket)
websocket.param_dict_asr_online = {"cache": dict()}
- websocket.speek_online = Queue()
- ss_online = threading.Thread(target=asr_online, args=(websocket,))
- ss_online.start()
try:
async for message in websocket:
@@ -53,54 +51,37 @@
is_speaking = message["is_speaking"]
websocket.param_dict_asr_online["is_final"] = not is_speaking
-
+ websocket.wav_name = message.get("wav_name", "demo")
websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
-
- frames_online.append(audio)
-
- if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking:
-
- audio_in = b"".join(frames_online)
- websocket.speek_online.put(audio_in)
- frames_online = []
+ frames_asr_online.append(audio)
+ if len(frames_asr_online) % message["chunk_interval"] == 0 or not is_speaking:
+ audio_in = b"".join(frames_asr_online)
+ await async_asr_online(websocket,audio_in)
+ frames_asr_online = []
- if not websocket.send_msg.empty():
- await websocket.send(websocket.send_msg.get())
- websocket.send_msg.task_done()
except websockets.ConnectionClosed:
- print("ConnectionClosed...", websocket_users) # 閾炬帴鏂紑
+ print("ConnectionClosed...", websocket_users)
websocket_users.remove(websocket)
except websockets.InvalidState:
- print("InvalidState...") # 鏃犳晥鐘舵��
+ print("InvalidState...")
except Exception as e:
print("Exception:", e)
-
-
-def asr_online(websocket): # ASR鎺ㄧ悊
- global websocket_users
- while websocket in websocket_users:
- if not websocket.speek_online.empty():
- audio_in = websocket.speek_online.get()
- websocket.speek_online.task_done()
+async def async_asr_online(websocket,audio_in):
if len(audio_in) > 0:
- # print(len(audio_in))
audio_in = load_bytes(audio_in)
rec_result = inference_pipeline_asr_online(audio_in=audio_in,
param_dict=websocket.param_dict_asr_online)
if websocket.param_dict_asr_online["is_final"]:
websocket.param_dict_asr_online["cache"] = dict()
-
if "text" in rec_result:
if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
- print(rec_result["text"])
- message = json.dumps({"mode": "online", "text": rec_result["text"]})
- websocket.send_msg.put(message)
-
- time.sleep(0.005)
+ message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
+ await websocket.send(message)
+
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
new file mode 100644
index 0000000..e89537b
--- /dev/null
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -0,0 +1,64 @@
+cmake_minimum_required(VERSION 3.10)
+
+project(FunASRWebscoket)
+
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+
+option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
+
+if(ENABLE_WEBSOCKET)
+ # cmake_policy(SET CMP0135 NEW)
+
+ include(FetchContent)
+ FetchContent_Declare(websocketpp
+ GIT_REPOSITORY https://github.com/zaphoyd/websocketpp.git
+ GIT_TAG 0.8.2
+ SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/websocket
+ )
+
+ FetchContent_MakeAvailable(websocketpp)
+ include_directories(${PROJECT_SOURCE_DIR}/third_party/websocket)
+
+
+ FetchContent_Declare(asio
+ URL https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
+ SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
+ )
+
+ FetchContent_MakeAvailable(asio)
+ include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)
+
+ FetchContent_Declare(json
+ URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
+ SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
+ )
+
+ FetchContent_MakeAvailable(json)
+ include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
+
+
+
+endif()
+
+# Include generated *.pb.h files
+link_directories(${ONNXRUNTIME_DIR}/lib)
+
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
+
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
+
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
+set(BUILD_TESTING OFF)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
+
+
+add_executable(websocketmain "websocketmain.cpp" "websocketsrv.cpp")
+add_executable(websocketclient "websocketclient.cpp")
+
+target_link_libraries(websocketclient PUBLIC funasr)
+target_link_libraries(websocketmain PUBLIC funasr)
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
new file mode 100644
index 0000000..078184e
--- /dev/null
+++ b/funasr/runtime/websocket/readme.md
@@ -0,0 +1,99 @@
+# Service with websocket-cpp
+
+## Export the model
+### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
+
+```shell
+# pip3 install torch torchaudio
+pip install -U modelscope funasr
+# For the users in China, you could install with the command:
+# pip install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
+### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
+
+```shell
+python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
+```
+
+## Building for Linux/Unix
+
+### Download onnxruntime
+```shell
+# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
+# here we get a copy of onnxruntime for linux 64
+wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
+tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
+```
+
+### Install openblas
+```shell
+sudo apt-get install libopenblas-dev #ubuntu
+# sudo yum -y install openblas-devel #centos
+```
+
+### Build runtime
+```shell
+git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/websocket
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
+make
+```
+## Run the websocket server
+
+```shell
+cd bin
+./websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
+ [--io_thread_num <int>] [--port <int>] [--listen_ip
+ <string>] [--punc-quant <string>] [--punc-dir <string>]
+ [--vad-quant <string>] [--vad-dir <string>] [--quantize
+ <string>] --model-dir <string> [--] [--version] [-h]
+Where:
+ --model-dir <string>
+ (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
+ --quantize <string>
+ false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+
+ --vad-dir <string>
+ the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ --vad-quant <string>
+ false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+
+ --punc-dir <string>
+ the punc model path, which contains model.onnx, punc.yaml
+ --punc-quant <string>
+ false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+
+ --decoder_thread_num <int>
+ number of threads for decoder, default:8
+ --io_thread_num <int>
+ number of threads for network io, default:8
+ --port <int>
+ listen port, default:8889
+
+ Required: --model-dir <string>
+ If use vad, please add: --vad-dir <string>
+ If use punc, please add: --punc-dir <string>
+example:
+ websocketmain --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+```
+
+## Run websocket client test
+
+```shell
+Usage: websocketclient server_ip port wav_path threads_num
+
+example:
+
+websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64
+
+result json, example like:
+{"text":"涓�浜屼笁鍥涗簲鍏竷鍏節鍗佷竴浜屼笁鍥涗簲鍏竷鍏節鍗�"}
+```
+
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/add-offline-websocket-srv/funasr/runtime/websocket) for contributing the websocket(cpp-api).
+
+
diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
new file mode 100644
index 0000000..3ab4e99
--- /dev/null
+++ b/funasr/runtime/websocket/websocketclient.cpp
@@ -0,0 +1,221 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// client for websocket, support multiple threads
+// Usage: websocketclient server_ip port wav_path threads_num
+
+#define ASIO_STANDALONE 1
+#include <websocketpp/client.hpp>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio_no_tls_client.hpp>
+
+#include "audio.h"
+
+/**
+ * Define a semi-cross platform helper method that waits/sleeps for a bit.
+ */
+void wait_a_bit() {
+#ifdef WIN32
+ Sleep(1000);
+#else
+ sleep(1);
+#endif
+}
+typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+
+class websocket_client {
+ public:
+ typedef websocketpp::client<websocketpp::config::asio_client> client;
+ typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
+
+ websocket_client() : m_open(false), m_done(false) {
+ // set up access channels to only log interesting things
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ m_client.set_access_channels(websocketpp::log::alevel::connect);
+ m_client.set_access_channels(websocketpp::log::alevel::disconnect);
+ m_client.set_access_channels(websocketpp::log::alevel::app);
+
+ // Initialize the Asio transport policy
+ m_client.init_asio();
+
+ // Bind the handlers we are using
+ using websocketpp::lib::bind;
+ using websocketpp::lib::placeholders::_1;
+ m_client.set_open_handler(bind(&websocket_client::on_open, this, _1));
+ m_client.set_close_handler(bind(&websocket_client::on_close, this, _1));
+ m_client.set_close_handler(bind(&websocket_client::on_close, this, _1));
+
+ m_client.set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+
+ m_client.set_fail_handler(bind(&websocket_client::on_fail, this, _1));
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ }
+ void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
+ const std::string& payload = msg->get_payload();
+ switch (msg->get_opcode()) {
+ case websocketpp::frame::opcode::text:
+ std::cout << "on_message=" << payload << std::endl;
+ }
+ }
+ // This method will block until the connection is complete
+ void run(const std::string& uri, const std::string& wav_path) {
+ // Create a new connection to the given URI
+ websocketpp::lib::error_code ec;
+ client::connection_ptr con = m_client.get_connection(uri, ec);
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Get Connection Error: " + ec.message());
+ return;
+ }
+ this->wav_path = std::move(wav_path);
+ // Grab a handle for this connection so we can talk to it in a thread
+ // safe manor after the event loop starts.
+ m_hdl = con->get_handle();
+
+ // Queue the connection. No DNS queries or network connections will be
+ // made until the io_service event loop is run.
+ m_client.connect(con);
+
+ // Create a thread to run the ASIO io_service event loop
+ websocketpp::lib::thread asio_thread(&client::run, &m_client);
+
+ send_wav_data();
+ asio_thread.join();
+ }
+
+ // The open handler will signal that we are ready to start sending data
+ void on_open(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection opened, starting data!");
+
+ scoped_lock guard(m_lock);
+ m_open = true;
+ }
+
+ // The close handler will signal that we should stop sending data
+ void on_close(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection closed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+
+ // The fail handler will signal that we should stop sending data
+ void on_fail(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection failed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+ // send wav to server
+ void send_wav_data() {
+ uint64_t count = 0;
+ std::stringstream val;
+
+ funasr::Audio audio(1);
+ int32_t sampling_rate = 16000;
+
+ if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) {
+ std::cout << "error in load wav" << std::endl;
+ return;
+ }
+
+ float* buff;
+ int len;
+ int flag = 0;
+ bool wait = false;
+ while (1) {
+ {
+ scoped_lock guard(m_lock);
+ // If the connection has been closed, stop generating data
+ if (m_done) {
+ break;
+ }
+
+ // If the connection hasn't been opened yet wait a bit and retry
+ if (!m_open) {
+ wait = true;
+ } else {
+ break;
+ }
+ }
+
+ if (wait) {
+ std::cout << "wait.." << m_open << std::endl;
+ wait_a_bit();
+
+ continue;
+ }
+ }
+ websocketpp::lib::error_code ec;
+ // fetch wav data use asr engine api
+ while (audio.Fetch(buff, len, flag) > 0) {
+ short iArray[len];
+
+ // convert float -1,1 to short -32768,32767
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buff[i] * 32767);
+ }
+ // send data to server
+ m_client.send(m_hdl, iArray, len * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ std::cout << "sended data len=" << len * sizeof(short) << std::endl;
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ break;
+ }
+
+ wait_a_bit();
+ }
+
+ m_client.send(m_hdl, "Done", websocketpp::frame::opcode::text, ec);
+ wait_a_bit();
+ }
+
+ private:
+ client m_client;
+ websocketpp::connection_hdl m_hdl;
+ websocketpp::lib::mutex m_lock;
+ std::string wav_path;
+ bool m_open;
+ bool m_done;
+};
+
+int main(int argc, char* argv[]) {
+ if (argc < 5) {
+ printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
+ exit(-1);
+ }
+ std::string server_ip = argv[1];
+ std::string port = argv[2];
+ std::string wav_path = argv[3];
+ int threads_num = atoi(argv[4]);
+ std::vector<websocketpp::lib::thread> client_threads;
+
+ std::string uri = "ws://" + server_ip + ":" + port;
+
+ for (size_t i = 0; i < threads_num; i++) {
+ client_threads.emplace_back([uri, wav_path]() {
+ websocket_client c;
+ c.run(uri, wav_path);
+ });
+ }
+
+ for (auto& t : client_threads) {
+ t.join();
+ }
+}
\ No newline at end of file
diff --git a/funasr/runtime/websocket/websocketmain.cpp b/funasr/runtime/websocket/websocketmain.cpp
new file mode 100644
index 0000000..4614b51
--- /dev/null
+++ b/funasr/runtime/websocket/websocketmain.cpp
@@ -0,0 +1,149 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// io server
+// Usage:websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
+// [--io_thread_num <int>] [--port <int>] [--listen_ip
+// <string>] [--punc-quant <string>] [--punc-dir <string>]
+// [--vad-quant <string>] [--vad-dir <string>] [--quantize
+// <string>] --model-dir <string> [--] [--version] [-h]
+#include "websocketsrv.h"
+
+using namespace std;
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
+ std::map<std::string, std::string>& model_path) {
+ if (value_arg.isSet()) {
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO) << key << " : " << value_arg.getValue();
+ }
+}
+int main(int argc, char* argv[]) {
+ try {
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("websocketmain", ' ', "1.0");
+ TCLAP::ValueArg<std::string> model_dir(
+ "", MODEL_DIR,
+ "the asr model path, which contains model.onnx, config.yaml, am.mvn",
+ true, "", "string");
+ TCLAP::ValueArg<std::string> quantize(
+ "", QUANTIZE,
+ "false (Default), load the model of model.onnx in model_dir. If set "
+ "true, load the model of model_quant.onnx in model_dir",
+ false, "false", "string");
+ TCLAP::ValueArg<std::string> vad_dir(
+ "", VAD_DIR,
+ "the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
+ false, "", "string");
+ TCLAP::ValueArg<std::string> vad_quant(
+ "", VAD_QUANT,
+ "false (Default), load the model of model.onnx in vad_dir. If set "
+ "true, load the model of model_quant.onnx in vad_dir",
+ false, "false", "string");
+ TCLAP::ValueArg<std::string> punc_dir(
+ "", PUNC_DIR,
+ "the punc model path, which contains model.onnx, punc.yaml", false, "",
+ "string");
+ TCLAP::ValueArg<std::string> punc_quant(
+ "", PUNC_QUANT,
+ "false (Default), load the model of model.onnx in punc_dir. If set "
+ "true, load the model of model_quant.onnx in punc_dir",
+ false, "false", "string");
+
+ TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
+ "0.0.0.0", "string");
+ TCLAP::ValueArg<int> port("", "port", "port", false, 8889, "int");
+ TCLAP::ValueArg<int> io_thread_num("", "io_thread_num", "io_thread_num",
+ false, 8, "int");
+ TCLAP::ValueArg<int> decoder_thread_num(
+ "", "decoder_thread_num", "decoder_thread_num", false, 8, "int");
+ TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
+ "model_thread_num", false, 1, "int");
+
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_quant);
+
+ cmd.add(listen_ip);
+ cmd.add(port);
+ cmd.add(io_thread_num);
+ cmd.add(decoder_thread_num);
+ cmd.add(model_thread_num);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(vad_dir, VAD_DIR, model_path);
+ GetValue(vad_quant, VAD_QUANT, model_path);
+ GetValue(punc_dir, PUNC_DIR, model_path);
+ GetValue(punc_quant, PUNC_QUANT, model_path);
+
+ std::string s_listen_ip = listen_ip.getValue();
+ int s_port = port.getValue();
+ int s_io_thread_num = io_thread_num.getValue();
+ int s_decoder_thread_num = decoder_thread_num.getValue();
+
+ int s_model_thread_num = model_thread_num.getValue();
+
+ asio::io_context io_decoder; // context for decoding
+
+ std::vector<std::thread> decoder_threads;
+
+ auto conn_guard = asio::make_work_guard(
+ io_decoder); // make sure threads can wait in the queue
+
+ // create threads pool
+ for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
+ decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
+ }
+
+ server server_; // server for websocket
+ server_.init_asio(); // init asio
+ server_.set_reuse_addr(
+ true); // reuse address as we create multiple threads
+
+ // list on port for accept
+ server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+
+ WebSocketServer websocket_srv(io_decoder,
+ &server_); // websocket server for asr engine
+ websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+ std::cout << "asr model init finished. listen on port:" << s_port
+ << std::endl;
+
+ // Start the ASIO network io_service run loop
+ if (s_io_thread_num == 1) {
+ server_.run();
+ } else {
+ typedef websocketpp::lib::shared_ptr<websocketpp::lib::thread> thread_ptr;
+ std::vector<thread_ptr> ts;
+ // create threads for io network
+ for (size_t i = 0; i < s_io_thread_num; i++) {
+ ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
+ &server::run, &server_));
+ }
+ // wait for theads
+ for (size_t i = 0; i < s_io_thread_num; i++) {
+ ts[i]->join();
+ }
+ }
+
+ // wait for theads
+ for (auto& t : decoder_threads) {
+ t.join();
+ }
+
+ } catch (std::exception const& e) {
+ std::cerr << "Error: " << e.what() << std::endl;
+ }
+
+ return 0;
+}
\ No newline at end of file
diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocketsrv.cpp
new file mode 100644
index 0000000..9e56667
--- /dev/null
+++ b/funasr/runtime/websocket/websocketsrv.cpp
@@ -0,0 +1,158 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// websocket server for asr engine
+// take some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two threads
+// pools, one for handle network data and one for asr decoder.
+// now only support offline engine.
+
+#include "websocketsrv.h"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// feed buffer to asr engine for decoder
+void WebSocketServer::do_decoder(const std::vector<char>& buffer,
+ websocketpp::connection_hdl& hdl) {
+ try {
+ int num_samples = buffer.size(); // the size of the buf
+
+ if (!buffer.empty()) {
+ // fout.write(buffer.data(), buffer.size());
+ // feed data to asr engine
+ FUNASR_RESULT Result = FunOfflineRecogPCMBuffer(
+ asr_hanlde, buffer.data(), buffer.size(), 16000, RASR_NONE, NULL);
+
+ std::string asr_result =
+ ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
+
+ websocketpp::lib::error_code ec;
+ nlohmann::json jsonresult; // result json
+ jsonresult["text"] = asr_result; // put result in 'text'
+
+ // send the json to client
+ server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+
+ std::cout << "buffer.size=" << buffer.size()
+ << ",result json=" << jsonresult.dump() << std::endl;
+ if (!isonline) {
+ // close the client if it is not online asr
+ server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
+ // fout.close();
+ }
+ }
+
+ } catch (std::exception const& e) {
+ std::cerr << "Error: " << e.what() << std::endl;
+ }
+}
+
+void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
+ scoped_lock guard(m_lock); // for threads safty
+ check_and_clean_connection(); // remove closed connection
+ sample_map.emplace(
+ hdl, std::make_shared<std::vector<char>>()); // put a new data vector for
+ // new connection
+ std::cout << "on_open, active connections: " << sample_map.size()
+ << std::endl;
+}
+
+void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
+ scoped_lock guard(m_lock);
+ sample_map.erase(hdl); // remove data vector when connection is closed
+ std::cout << "on_close, active connections: " << sample_map.size()
+ << std::endl;
+}
+
+// remove closed connection
+void WebSocketServer::check_and_clean_connection() {
+ std::vector<websocketpp::connection_hdl> to_remove; // remove list
+ auto iter = sample_map.begin();
+ while (iter != sample_map.end()) { // loop to find closed connection
+ websocketpp::connection_hdl hdl = iter->first;
+ server::connection_ptr con = server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ iter++;
+ }
+ for (auto hdl : to_remove) {
+ sample_map.erase(hdl);
+ std::cout << "remove one connection " << std::endl;
+ }
+}
+void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
+ message_ptr msg) {
+ unique_lock lock(m_lock);
+ // find the sample data vector according to one connection
+ std::shared_ptr<std::vector<char>> sample_data_p = nullptr;
+
+ auto it = sample_map.find(hdl);
+ if (it != sample_map.end()) {
+ sample_data_p = it->second;
+ }
+ lock.unlock();
+ if (sample_data_p == nullptr) {
+ std::cout << "error when fetch sample data vector" << std::endl;
+ return;
+ }
+
+ const std::string& payload = msg->get_payload(); // get msg type
+
+ switch (msg->get_opcode()) {
+ case websocketpp::frame::opcode::text:
+ if (payload == "Done") {
+ std::cout << "client done" << std::endl;
+
+ if (isonline) {
+ // do_close(ws);
+ } else {
+ // for offline, send all receive data to decoder engine
+ asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
+ std::move(*(sample_data_p.get())),
+ std::move(hdl)));
+ }
+ }
+ break;
+ case websocketpp::frame::opcode::binary: {
+ // recived binary data
+ const auto* pcm_data = static_cast<const char*>(payload.data());
+ int32_t num_samples = payload.size();
+
+ if (isonline) {
+ // if online TODO(zhaoming) still not done
+ std::vector<char> s(pcm_data, pcm_data + num_samples);
+ asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
+ std::move(s), std::move(hdl)));
+ } else {
+ // for offline, we add receive data to end of the sample data vector
+ sample_data_p->insert(sample_data_p->end(), pcm_data,
+ pcm_data + num_samples);
+ }
+
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+// init asr model
+void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
+ int thread_num) {
+ try {
+ // init model with api
+
+ asr_hanlde = FunOfflineInit(model_path, thread_num);
+ std::cout << "model ready" << std::endl;
+
+ } catch (const std::exception& e) {
+ std::cout << e.what() << std::endl;
+ }
+}
diff --git a/funasr/runtime/websocket/websocketsrv.h b/funasr/runtime/websocket/websocketsrv.h
new file mode 100644
index 0000000..e484724
--- /dev/null
+++ b/funasr/runtime/websocket/websocketsrv.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// websocket server for asr engine
+// take some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two threads
+// pools, one for handle network data and one for asr decoder.
+// now only support offline engine.
+
+#ifndef WEBSOCKETSRV_SERVER_H_
+#define WEBSOCKETSRV_SERVER_H_
+
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#define ASIO_STANDALONE 1 // not boost
+#include <glog/logging.h>
+
+#include <fstream>
+#include <functional>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio_no_tls.hpp>
+#include <websocketpp/server.hpp>
+
+#include "asio.hpp"
+#include "com-define.h"
+#include "funasrruntime.h"
+#include "nlohmann/json.hpp"
+#include "tclap/CmdLine.h"
+typedef websocketpp::server<websocketpp::config::asio> server;
+typedef server::message_ptr message_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
+typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
+
+typedef struct {
+ std::string msg;
+ float snippet_time;
+} FUNASR_RECOG_RESULT;
+
+class WebSocketServer {
+ public:
+ WebSocketServer(asio::io_context& io_decoder, server* server_)
+ : io_decoder_(io_decoder), server_(server_) {
+ // set message handle
+ server_->set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+ // set open handle
+ server_->set_open_handler(
+ [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+ // set close handle
+ server_->set_close_handler(
+ [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+ // begin accept
+ server_->start_accept();
+ // not print log
+ server_->clear_access_channels(websocketpp::log::alevel::all);
+ }
+ void do_decoder(const std::vector<char>& buffer,
+ websocketpp::connection_hdl& hdl);
+
+ void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
+ void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
+ void on_open(websocketpp::connection_hdl hdl);
+ void on_close(websocketpp::connection_hdl hdl);
+
+ private:
+ void check_and_clean_connection();
+ asio::io_context& io_decoder_; // threads for asr decoder
+ // std::ofstream fout;
+ FUNASR_HANDLE asr_hanlde; // asr engine handle
+ bool isonline = false; // online or offline engine, now only support offline
+ server* server_; // websocket server
+
+ // use map to keep the received samples data from one connection in offline
+ // engine. if for online engline, a data struct is needed(TODO)
+ std::map<websocketpp::connection_hdl, std::shared_ptr<std::vector<char>>,
+ std::owner_less<websocketpp::connection_hdl>>
+ sample_map;
+ websocketpp::lib::mutex m_lock; // mutex for sample_map
+};
+
+#endif // WEBSOCKETSRV_SERVER_H_
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 3d2004c..55a5d79 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -445,6 +445,12 @@
help='Perform on "collect stats" mode',
)
group.add_argument(
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+ group.add_argument(
"--write_collected_feats",
type=str2bool,
default=False,
@@ -549,6 +555,12 @@
help="The number of gradient accumulation",
)
group.add_argument(
+ "--bias_grad_times",
+ type=float,
+ default=1.0,
+ help="To scale the gradient of contextual related params",
+ )
+ group.add_argument(
"--no_forward_run",
type=str2bool,
default=False,
@@ -635,8 +647,8 @@
group.add_argument(
"--init_param",
type=str,
+ action="append",
default=[],
- nargs="*",
help="Specify the file path used for initialization of parameters. "
"The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
"where file_path is the model file path, "
@@ -662,7 +674,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
@@ -1153,10 +1165,10 @@
elif args.distributed and args.simple_ddp:
distributed_option.init_torch_distributed_pai(args)
args.ngpu = dist.get_world_size()
- if args.dataset_type == "small":
+ if args.dataset_type == "small" and args.ngpu > 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
+ if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
@@ -1316,6 +1328,7 @@
data_path_and_name_and_type=args.train_data_path_and_name_and_type,
key_file=train_key_file,
batch_size=args.batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
@@ -1327,6 +1340,7 @@
data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
key_file=valid_key_file,
batch_size=args.valid_batch_size,
+ mc=args.mc,
dtype=args.train_dtype,
num_workers=args.num_workers,
allow_variable_data_keys=args.allow_variable_data_keys,
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index d52c9c3..43ea5ab 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -42,6 +42,7 @@
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
@@ -128,6 +129,7 @@
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
+ neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
@@ -1647,7 +1649,6 @@
normalize = None
# 4. Encoder
-
if getattr(args, "encoder", None) is not None:
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size, **args.encoder_conf)
@@ -1683,7 +1684,7 @@
# 7. Build model
- if encoder.unified_model_training:
+ if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
new file mode 100644
index 0000000..7cfcbd0
--- /dev/null
+++ b/funasr/tasks/sa_asr.py
@@ -0,0 +1,623 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+ DynamicConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolutionTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34,ResNet34Diar
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ multichannelfrontend=MultiChannelFrontend,
+ ),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ asr=ESPnetASRModel,
+ uniasr=UniASR,
+ paraformer=Paraformer,
+ paraformer_bert=ParaformerBert,
+ bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
+ mfcca=MFCCA,
+ timestamp_prediction=TimestampPredictor,
+ ),
+ type_check=AbsESPnetModel,
+ default="asr",
+)
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ linear=LinearProjection,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
+
+encoder_choices2 = ClassChoices(
+ "encoder2",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+postencoder_choices = ClassChoices(
+ name="postencoder",
+ classes=dict(
+ hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+ ),
+ type_check=AbsPostEncoder,
+ default=None,
+ optional=True,
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="sa_decoder",
+)
+decoder_choices2 = ClassChoices(
+ "decoder2",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="rnn",
+)
+predictor_choices = ClassChoices(
+ name="predictor",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ cif_predictor_v3=CifPredictorV3,
+ ),
+ type_check=None,
+ default="cif_predictor",
+ optional=True,
+)
+predictor_choices2 = ClassChoices(
+ name="predictor2",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ ),
+ type_check=None,
+ default="cif_predictor",
+ optional=True,
+)
+stride_conv_choices = ClassChoices(
+ name="stride_conv",
+ classes=dict(
+ stride_conv1d=Conv1dSubsampling
+ ),
+ type_check=None,
+ default="stride_conv1d",
+ optional=True,
+)
+
+
+class ASRTask(AbsTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
+ # --postencoder and --postencoder_conf
+ postencoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(description="Task related")
+
+ # NOTE(kamo): add_arguments(..., required=True) can't be used
+ # to provide --print_config mode. Instead of it, do as
+ # required = parser.get_default("required")
+ # required += ["token_list"]
+
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
+ "--max_spk_num",
+ type=int_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
+ "--seg_dict_file",
+ type=str,
+ default=None,
+ help="seg_dict_file for text processing",
+ )
+ group.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ group.add_argument(
+ "--ctc_conf",
+ action=NestedDictAction,
+ default=get_default_kwargs(CTC),
+ help="The keyword arguments for CTC class.",
+ )
+ group.add_argument(
+ "--joint_net_conf",
+ action=NestedDictAction,
+ default=None,
+ help="The keyword arguments for joint network class.",
+ )
+
+ group = parser.add_argument_group(description="Preprocess related")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word", "phn"],
+ help="The text will be tokenized " "in the specified level token",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model file of sentencepiece",
+ )
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ default=None,
+ help="non_linguistic_symbols file path",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Scale the maximum amplitude to the given value.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of rir scp file.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="THe probability for applying RIR convolution.",
+ )
+ parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability applying Noise adding.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of noise decibel level.",
+ )
+
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ assert check_argument_types()
+ # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+ return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ assert check_argument_types()
+ if args.use_preprocessor:
+ retval = CommonPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ # NOTE(kamo): Check attribute existence for backward compatibility
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob
+ if hasattr(args, "rir_apply_prob")
+ else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob
+ if hasattr(args, "noise_apply_prob")
+ else 1.0,
+ noise_db_range=args.noise_db_range
+ if hasattr(args, "noise_db_range")
+ else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize
+ if hasattr(args, "rir_scp")
+ else None,
+ )
+ else:
+ retval = None
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ if not inference:
+ retval = ("speech", "text")
+ else:
+ # Recognition mode
+ retval = ("speech",)
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ retval = ()
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+ # 6. Post-encoder block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ asr_encoder_output_size = asr_encoder.output_size()
+ if getattr(args, "postencoder", None) is not None:
+ postencoder_class = postencoder_choices.get_class(args.postencoder)
+ postencoder = postencoder_class(
+ input_size=asr_encoder_output_size, **args.postencoder_conf
+ )
+ asr_encoder_output_size = postencoder.output_size()
+ else:
+ postencoder = None
+
+ # 7. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder_output_size,
+ **args.decoder_conf,
+ )
+
+ # 8. CTC
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder_output_size, **args.ctc_conf
+ )
+
+ max_spk_num=int(args.max_spk_num)
+
+ # import ipdb;ipdb.set_trace()
+ # 9. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+ model = model_class(
+ vocab_size=vocab_size,
+ max_spk_num=max_spk_num,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
diff --git a/funasr/torch_utils/load_pretrained_model.py b/funasr/torch_utils/load_pretrained_model.py
index e9b18cd..b54f777 100644
--- a/funasr/torch_utils/load_pretrained_model.py
+++ b/funasr/torch_utils/load_pretrained_model.py
@@ -120,6 +120,6 @@
if ignore_init_mismatch:
src_state = filter_state_dict(dst_state, src_state)
- logging.info("Loaded src_state keys: {}".format(src_state.keys()))
+ # logging.info("Loaded src_state keys: {}".format(src_state.keys()))
dst_state.update(src_state)
obj.load_state_dict(dst_state)
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 7c187e9..a40f031 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -95,6 +95,7 @@
use_pai: bool
oss_bucket: Union[oss2.Bucket, None]
batch_interval: int
+ bias_grad_times: float
class Trainer:
"""Trainer having a optimizer.
@@ -546,8 +547,11 @@
no_forward_run = options.no_forward_run
ngpu = options.ngpu
use_wandb = options.use_wandb
+ bias_grad_times = options.bias_grad_times
distributed = distributed_option.distributed
+ if bias_grad_times != 1.0:
+ logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
if log_interval is None:
try:
log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +694,16 @@
scale_factor=0.55,
)
+ # for contextual training
+ if bias_grad_times != 1.0:
+ # contextual related parameter names
+ cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+ for name, param in model.named_parameters():
+ for cr_pname in cr_pnames:
+ if cr_pname in name:
+ param.grad *= bias_grad_times
+ continue
+
# compute the gradient norm to check if it is normal or not
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index b607e1d..f4efea6 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -242,4 +242,4 @@
if ch != ' ':
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
- return sentence, real_word_lists
+ return sentence, real_word_lists
\ No newline at end of file
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 87cc49e..489d317 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -94,19 +94,33 @@
res.append({
'text': text_postprocessed.split(),
"start": time_stamp_postprocessed[0][0],
- "end": time_stamp_postprocessed[-1][1]
+ "end": time_stamp_postprocessed[-1][1],
+ 'text_seg': text_postprocessed.split(),
+ "ts_list": time_stamp_postprocessed,
})
return res
if len(punc_id_list) != len(time_stamp_postprocessed):
print(" warning length mistach!!!!!!")
- sentence_text = ''
+ sentence_text = ""
+ sentence_text_seg = ""
+ ts_list = []
sentence_start = time_stamp_postprocessed[0][0]
sentence_end = time_stamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
for punc_stamp_text in punc_stamp_text_list:
punc_id, time_stamp, text = punc_stamp_text
- sentence_text += text if text is not None else ''
+ # sentence_text += text if text is not None else ''
+ if text is not None:
+ if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
+ sentence_text += ' ' + text
+ elif len(sentence_text) and ('a' <= sentence_text[-1] <= 'z' or 'A' <= sentence_text[-1] <= 'Z'):
+ sentence_text += ' ' + text
+ else:
+ sentence_text += text
+ sentence_text_seg += text + ' '
+ ts_list.append(time_stamp)
+
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
@@ -115,27 +129,39 @@
res.append({
'text': sentence_text,
"start": sentence_start,
- "end": sentence_end
+ "end": sentence_end,
+ "text_seg": sentence_text_seg,
+ "ts_list": ts_list
})
sentence_text = ''
+ sentence_text_seg = ''
+ ts_list = []
sentence_start = sentence_end
elif punc_id == 3:
sentence_text += '.'
res.append({
'text': sentence_text,
"start": sentence_start,
- "end": sentence_end
+ "end": sentence_end,
+ "text_seg": sentence_text_seg,
+ "ts_list": ts_list
})
sentence_text = ''
+ sentence_text_seg = ''
+ ts_list = []
sentence_start = sentence_end
elif punc_id == 4:
sentence_text += '?'
res.append({
'text': sentence_text,
"start": sentence_start,
- "end": sentence_end
+ "end": sentence_end,
+ "text_seg": sentence_text_seg,
+ "ts_list": ts_list
})
sentence_text = ''
+ sentence_text_seg = ''
+ ts_list = []
sentence_start = sentence_end
return res
diff --git a/funasr/version.txt b/funasr/version.txt
index 6f2743d..cb498ab 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.4.4
+0.4.8
diff --git a/setup.py b/setup.py
index e837637..ea55606 100644
--- a/setup.py
+++ b/setup.py
@@ -13,7 +13,7 @@
"install": [
"setuptools>=38.5.1",
# "configargparse>=1.2.1",
- "typeguard<=2.13.3",
+ "typeguard==2.13.3",
"humanfriendly",
"scipy>=1.4.1",
# "filelock",
@@ -42,7 +42,10 @@
"oss2",
# "kaldi-native-fbank",
# timestamp
- "edit-distance"
+ "edit-distance",
+ # textgrid
+ "textgrid",
+ "protobuf==3.20.0",
],
# train: The modules invoked when training only.
"train": [
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 2f2f11d..9098ea6 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -112,6 +112,22 @@
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ def test_paraformer_large_online_common(self):
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online')
+ rec_result = inference_pipeline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+ logger.info("asr inference result: {0}".format(rec_result))
+
+ def test_paraformer_online_common(self):
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online')
+ rec_result = inference_pipeline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+ logger.info("asr inference result: {0}".format(rec_result))
+
def test_paraformer_tiny_commandword(self):
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
--
Gitblit v1.9.1