From 012903e42ec890ab5c50137beb365c3d94e731d1 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 30 六月 2023 11:21:28 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/models/e2e_asr.py | 2
funasr/datasets/small_datasets/preprocessor.py | 11
funasr/runtime/html5/static/wsconnecter.js | 14
funasr/runtime/python/websocket/wss_srv_asr.py | 8
egs/alimeeting/sa_asr/local/gen_oracle_embedding.py | 6
egs/alimeeting/sa_asr/local/download_xvector_model.py | 0
egs/alimeeting/sa_asr/local/text_normalize.pl | 0
funasr/bin/lm_inference_launch.py | 129
funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs | 717 +++
funasr/runtime/readme_cn.md | 31
egs/alimeeting/sa_asr/README.md | 86
funasr/modules/eend_ola/utils/__init__.py | 0
funasr/runtime/onnxruntime/src/tokenizer.h | 2
funasr/tasks/sa_asr.py | 14
funasr/text/cleaner.py | 2
egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl | 0
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py | 3
funasr/modules/eend_ola/utils/report.py | 2
funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs | 17
funasr/runtime/websocket/readme.md | 94
.github/workflows/UnitTest.yml | 7
tests/test_asr_inference_pipeline.py | 6
funasr/main_funcs/collect_stats.py | 2
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py | 3
funasr/models/encoder/conformer_encoder.py | 58
docs/m2met2/_build/doctrees/environment.pickle | 0
docs/m2met2/index.rst | 1
funasr/tasks/diar.py | 16
funasr/torch_utils/initialize.py | 2
funasr/bin/punc_inference_launch.py | 108
funasr/layers/mask_along_axis.py | 4
egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py | 6
funasr/text/sentencepiece_tokenizer.py | 2
docs/m2met2_cn/_build/html/genindex.html | 1
funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll | 0
funasr/bin/asr_infer.py | 702 +-
egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh | 0
egs/alimeeting/sa_asr/utils | 0
funasr/layers/label_aggregation.py | 2
docs/m2met2/_build/html/Baseline.html | 1
funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs | 27
funasr/models/e2e_vad.py | 23
funasr/runtime/docs/SDK_tutorial.md | 336 +
egs/alimeeting/sa_asr/local/convert_model.py | 29
docs/m2met2/_build/html/Dataset.html | 1
egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml | 0
funasr/models/decoder/rnnt_decoder.py | 2
funasr/runtime/docs/images/aliyun6.png | 0
funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs | 6
funasr/models/e2e_diar_eend_ola.py | 2
funasr/datasets/collate_fn.py | 7
docs/m2met2_cn/_build/html/组委会.html | 9
funasr/models/e2e_tp.py | 2
funasr/bin/punc_infer.py | 60
docs/m2met2/_build/html/_sources/index.rst.txt | 1
funasr/runtime/python/websocket/wss_client_asr.py | 275
funasr/train/abs_model.py | 3
funasr/runtime/csharp/AliFsmnVadSharp.sln | 37
funasr/text/build_tokenizer.py | 2
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml | 26
funasr/build_utils/build_diar_model.py | 22
funasr/models/preencoder/sinc.py | 3
funasr/models/decoder/sanm_decoder.py | 3
funasr/bin/vad_infer.py | 44
tests/test_tp_pipeline.py | 30
egs/alimeeting/sa_asr/local/compute_cpcer.py | 0
funasr/bin/train.py | 4
egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh | 0
funasr/runtime/docs/benchmark_onnx.md | 0
funasr/train/class_choices.py | 5
funasr/datasets/dataset.py | 6
egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py | 0
funasr/layers/utterance_mvn.py | 2
egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py | 0
funasr/text/word_tokenizer.py | 2
funasr/models/encoder/rnn_encoder.py | 2
docs/runtime/img.png | 0
funasr/train/reporter.py | 9
funasr/runtime/onnxruntime/src/offline-stream.cpp | 26
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp | 6
docs/m2met2_cn/_build/html/searchindex.js | 2
funasr/models/postencoder/hugging_face_transformers_postencoder.py | 2
funasr/runtime/docs/images/aliyun3.png | 0
docs/m2met2/Challenge_result.md | 14
funasr/samplers/length_batch_sampler.py | 2
funasr/utils/prepare_data.py | 7
funasr/main_funcs/average_nbest_models.py | 2
funasr/samplers/sorted_batch_sampler.py | 2
funasr/models/e2e_asr_contextual_paraformer.py | 6
funasr/schedulers/warmup_lr.py | 2
docs/m2met2_cn/_build/doctrees/比赛结果.doctree | 0
funasr/runtime/docs/benchmark_libtorch.md | 0
funasr/runtime/onnxruntime/src/paraformer.cpp | 1
funasr/runtime/python/websocket/parse_args.py | 2
tests/test_asr_vad_punc_inference_pipeline.py | 1
docs/m2met2/_build/doctrees/Challenge_result.doctree | 0
docs/m2met2/_build/html/Track_setting_and_evaluation.html | 1
docs/m2met2_cn/_build/html/search.html | 1
funasr/layers/global_mvn.py | 2
egs/alimeeting/sa_asr/local/fix_data_dir.sh | 0
funasr/datasets/large_datasets/build_dataloader.py | 2
funasr/train/trainer.py | 5
setup.py | 53
funasr/runtime/ssl_key/readme.md | 4
funasr/export/export_model.py | 65
funasr/runtime/docs/images/aliyun5.png | 0
funasr/utils/build_dataclass.py | 2
funasr/runtime/websocket/funasr-wss-server.cpp | 329 +
funasr/runtime/ssl_key/server.key | 38
egs/alimeeting/sa_asr/local/compute_cmvn.py | 134
egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh | 0
funasr/torch_utils/forward_adaptor.py | 2
funasr/bin/tp_inference_launch.py | 118
funasr/layers/stft.py | 2
funasr/models/encoder/transformer_encoder.py | 2
funasr/build_utils/build_lm_model.py | 9
funasr/build_utils/build_asr_model.py | 95
funasr/runtime/onnxruntime/include/com-define.h | 1
egs/alimeeting/sa_asr/local/format_wav_scp.sh | 0
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/finetune.py | 38
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/utils | 1
funasr/datasets/large_datasets/dataset.py | 11
egs/alimeeting/sa_asr/local/data/get_reco2dur.sh | 0
funasr/runtime/docs/images/aliyun11.png | 0
funasr/runtime/websocket/websocket-server.cpp | 30
egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py | 0
docs/m2met2_cn/_build/html/基线.html | 1
funasr/bin/sv_infer.py | 63
docs/installation/installation.md | 14
egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py | 0
egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py | 0
funasr/models/seq_rnn_lm.py | 2
funasr/fileio/read_text.py | 3
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py | 3
funasr/runtime/docs/images/aliyun10.png | 0
.gitignore | 3
funasr/samplers/build_batch_sampler.py | 4
funasr/iterators/multiple_iter_factory.py | 2
funasr/models/e2e_asr_mfcca.py | 4
funasr/models/e2e_sa_asr.py | 19
funasr/runtime/onnxruntime/src/tokenizer.cpp | 8
funasr/samplers/folded_batch_sampler.py | 2
funasr/models/encoder/mfcca_encoder.py | 2
funasr/samplers/unsorted_batch_sampler.py | 2
funasr/build_utils/build_streaming_iterator.py | 65
funasr/runtime/docs/images/aliyun9.png | 0
funasr/runtime/onnxruntime/src/commonfunc.h | 5
docs/benchmark/benchmark_libtorch.md | 2
funasr/models/e2e_uni_asr.py | 6
funasr/runtime/onnxruntime/include/funasrruntime.h | 11
funasr/models/encoder/opennmt_encoders/fsmn_encoder.py | 1
funasr/runtime/html5/static/main.js | 261 +
docs/images/dingding.jpg | 0
funasr/runtime/onnxruntime/src/punc-model.cpp | 12
funasr/utils/wav_utils.py | 15
egs/alimeeting/sa_asr/local/compute_cmvn.sh | 39
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 4
egs/alimeeting/sa_asr/local/download_and_untar.sh | 105
docs/m2met2/_build/html/_sources/Challenge_result.md.txt | 14
funasr/runtime/python/libtorch/setup.py | 11
egs/alimeeting/sa_asr/local/copy_data_dir.sh | 0
docs/m2met2_cn/_build/html/objects.inv | 0
funasr/runtime/websocket/CMakeLists.txt | 13
funasr/build_utils/build_model.py | 5
funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs | 185
funasr/models/frontend/default.py | 139
funasr/models/ctc.py | 7
funasr/runtime/onnxruntime/src/ct-transformer-online.h | 37
funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs | 61
funasr/runtime/onnxruntime/src/ct-transformer.cpp | 1
funasr/datasets/small_datasets/length_batch_sampler.py | 2
funasr/text/token_id_converter.py | 2
funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs | 22
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp | 283 +
funasr/build_utils/build_trainer.py | 5
docs/m2met2_cn/index.rst | 1
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.sh | 105
funasr/optimizers/sgd.py | 2
funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs | 72
funasr/tasks/vad.py | 8
funasr/tasks/punctuation.py | 9
funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs | 23
funasr/models/encoder/opennmt_encoders/conv_encoder.py | 2
docs/m2met2_cn/_build/doctrees/environment.pickle | 0
funasr/runtime/onnxruntime/src/precomp.h | 1
docs/m2met2_cn/_build/html/联系方式.html | 1
funasr/runtime/docs/images/aliyun4.png | 0
funasr/bin/sv_inference_launch.py | 113
funasr/models/encoder/sanm_encoder.py | 4
funasr/tasks/lm.py | 8
funasr/runtime/docs/images/aliyun12.png | 0
funasr/tasks/abs_task.py | 16
egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh | 50
funasr/fileio/datadir_writer.py | 6
funasr/runtime/docs/aliyun_server_tutorial.md | 74
funasr/models/encoder/data2vec_encoder.py | 2
docs/m2met2/_build/html/objects.inv | 0
funasr/runtime/websocket/websocket-server.h | 6
docs/benchmark/benchmark_onnx_cpp.md | 2
egs/alimeeting/sa_asr/local/text_format.pl | 0
funasr/models/e2e_diar_sond.py | 2
funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj | 37
docs/m2met2_cn/_build/html/index.html | 2
funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs | 23
funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj | 18
funasr/samplers/num_elements_batch_sampler.py | 2
docs/m2met2/_build/html/Rules.html | 9
docs/index.rst | 3
egs/alimeeting/sa_asr/local/validate_text.pl | 0
funasr/models/frontend/windowing.py | 2
funasr/fileio/rand_gen_dataset.py | 3
egs/alimeeting/sa_asr/local/data/split_data.sh | 0
funasr/datasets/iterable_dataset.py | 14
funasr/build_utils/build_model_from_file.py | 191
funasr/text/phoneme_tokenizer.py | 2
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 4
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py | 1
funasr/models/decoder/transformer_decoder.py | 10
funasr/runtime/docs/images/aliyun8.png | 0
funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs | 156
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml | 27
funasr/build_utils/build_vad_model.py | 4
funasr/models/decoder/rnn_decoder.py | 2
funasr/utils/griffin_lim.py | 2
docs/m2met2_cn/_build/html/赛道设置与评估.html | 1
funasr/runtime/html5/readme_cn.md | 2
egs/alimeeting/sa_asr/local/apply_map.pl | 0
funasr/models/e2e_asr_transducer.py | 15
docs/m2met2_cn/_build/html/规则.html | 9
egs/alimeeting/sa_asr/local/format_wav_scp.py | 2
funasr/bin/vad_inference_launch.py | 62
funasr/models/e2e_asr_paraformer.py | 6
funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs | 29
funasr/bin/diar_inference_launch.py | 68
funasr/runtime/ssl_key/server.crt | 32
funasr/utils/runtime_sdk_download_tool.py | 39
funasr/build_utils/build_sv_model.py | 256 +
docs/m2met2/_build/html/genindex.html | 1
egs/alimeeting/sa_asr/path.sh | 6
funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs | 387 +
funasr/runtime/onnxruntime/include/punc-model.h | 7
docs/m2met2/_build/html/Organizers.html | 9
docs/m2met2_cn/_build/html/_sources/index.rst.txt | 1
funasr/runtime/docs/images/aliyun1.png | 0
funasr/runtime/java/readme.md | 66
docs/m2met2/_build/html/Challenge_result.html | 247 +
funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs | 40
funasr/utils/asr_utils.py | 6
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/demo.py | 12
egs/alimeeting/sa_asr/local/data/get_utt2dur.sh | 0
funasr/build_utils/build_args.py | 19
funasr/models/e2e_sv.py | 2
funasr/models/preencoder/linear.py | 2
funasr/datasets/small_datasets/dataset.py | 6
egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl | 0
funasr/bin/diar_infer.py | 116
funasr/runtime/java/Makefile | 76
docs/m2met2_cn/_build/doctrees/index.doctree | 0
funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs | 98
docs/m2met2/_build/html/Contact.html | 1
funasr/runtime/docs/benchmark_onnx_cpp.md | 0
funasr/models/data2vec.py | 2
funasr/text/char_tokenizer.py | 2
funasr/runtime/docs/images/aliyun7.png | 0
funasr/runtime/onnxruntime/bin/CMakeLists.txt | 3
funasr/runtime/html5/readme.md | 16
funasr/runtime/html5/static/index.html | 32
docs/m2met2/_build/html/index.html | 2
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py | 2
README.md | 73
docs/m2met2_cn/_build/html/简介.html | 1
funasr/schedulers/noam_lr.py | 2
egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py | 3
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py | 5
funasr/datasets/small_datasets/collate_fn.py | 5
funasr/datasets/small_datasets/__init__.py | 0
egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py | 4
funasr/runtime/docs/SDK_tutorial_cn.md | 327 +
funasr/iterators/chunk_iter_factory.py | 2
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp | 130
funasr/fileio/sound_scp.py | 73
funasr/version.txt | 2
egs/alimeeting/sa_asr/local/combine_data.sh | 0
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 68
docs/m2met2/_build/html/search.html | 1
funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs | 19
egs/alimeeting/sa_asr/local/process_text_spk_merge.py | 0
funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs | 35
egs/alimeeting/sa_asr/run.sh | 435 ++
funasr/iterators/sequence_iter_factory.py | 2
funasr/models/frontend/fused.py | 2
docs/m2met2_cn/_build/html/比赛结果.html | 248 +
funasr/runtime/csharp/README.md | 59
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py | 2
funasr/datasets/preprocessor.py | 11
funasr/models/frontend/s3prl.py | 2
docs/m2met2/_build/html/searchindex.js | 2
funasr/fileio/npy_scp.py | 3
funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs | 26
tests/test_vad_inference_pipeline.py | 2
docs/benchmark/benchmark_pipeline_cer.md | 256
docs/m2met2_cn/_build/html/_sources/比赛结果.md.txt | 14
egs/alimeeting/sa_asr/local/process_text_id.py | 0
funasr/layers/sinc_conv.py | 4
funasr/runtime/python/onnxruntime/setup.py | 3
funasr/runtime/docs/images/aliyun2.png | 0
funasr/runtime/docs/SDK_advanced_guide_cn.md | 261 +
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py | 2
docs/m2met2/_build/html/Introduction.html | 1
funasr/tasks/sv.py | 9
funasr/models/decoder/contextual_decoder.py | 2
docs/m2met2_cn/_build/html/数据集.html | 1
funasr/tasks/data2vec.py | 8
docs/m2met2/_build/doctrees/index.doctree | 0
funasr/tasks/asr.py | 170
funasr/modules/subsampling.py | 21
funasr/runtime/readme.md | 30
funasr/models/frontend/wav_frontend.py | 4
funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py | 1
/dev/null | 190
funasr/runtime/websocket/funasr-wss-client.cpp | 378 +
egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py | 0
funasr/schedulers/tri_stage_scheduler.py | 2
funasr/bin/tp_infer.py | 68
funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs | 28
docs/benchmark/benchmark_onnx.md | 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/README.md | 3
docs/m2met2_cn/比赛结果.md | 14
tests/test_sv_inference_pipeline.py | 38
funasr/bin/tokenize_text.py | 2
egs/alimeeting/sa_asr/local/validate_data_dir.sh | 0
funasr/runtime/java/FunasrWsClient.java | 344 +
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/infer.py | 33
funasr/bin/asr_inference_launch.py | 885 ++--
funasr/runtime/onnxruntime/src/vocab.cpp | 9
funasr/runtime/python/libtorch/funasr_torch/utils/frontend.py | 2
336 files changed, 9,694 insertions(+), 2,840 deletions(-)
diff --git a/.github/workflows/UnitTest.yml b/.github/workflows/UnitTest.yml
index 3b0a1ee..8ced9e4 100644
--- a/.github/workflows/UnitTest.yml
+++ b/.github/workflows/UnitTest.yml
@@ -8,6 +8,7 @@
branches:
- dev_wjm
- dev_jy
+ - dev_wjm_infer
jobs:
build:
@@ -18,6 +19,12 @@
python-version: ["3.7"]
steps:
+ - name: Remove unnecessary files
+ run:
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
diff --git a/.gitignore b/.gitignore
index 58bee36..d47674c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,5 @@
build
funasr.egg-info
docs/_build
-modelscope
\ No newline at end of file
+modelscope
+samples
\ No newline at end of file
diff --git a/README.md b/README.md
index 7c289e0..26cf940 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
| [**Highlights**](#highlights)
| [**Installation**](#installation)
-| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
+| [**Usage**](#usage)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
@@ -34,9 +34,9 @@
Install from pip
```shell
-pip install -U funasr
+pip3 install -U funasr
# For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
Or install from source code
@@ -44,22 +44,71 @@
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install -e ./
+pip3 install -e ./
# For the users in China, you could install with the command:
-# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
If you want to use the pretrained models in ModelScope, you should install the modelscope:
```shell
-pip install -U modelscope
+pip3 install -U modelscope
# For the users in China, you could install with the command:
-# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
For more details, please ref to [installation](https://alibaba-damo-academy.github.io/FunASR/en/installation/installation.html)
+## Usage
+You could use FunASR by:
+
+- egs
+- egs_modelscope
+- runtime
+
+### egs
+If you want to train the model from scratch, you could use funasr directly by recipe, as the following:
+```shell
+cd egs/aishell/paraformer
+. ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
+```
+More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
+
+### egs_modelscope
+If you want to infer or finetune pretraining models from modelscope, you could use funasr by modelscope pipeline, as the following:
+
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
+# {'text': '娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�'}
+```
+More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
+
+### runtime
+
+An example with websocket:
+
+For the server:
+```shell
+cd funasr/runtime/python/websocket
+python wss_srv_asr.py --port 10095
+```
+
+For the client:
+```shell
+python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
+#python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+```
+More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2)
## Contact
If you have any questions about FunASR, please contact us by
@@ -72,8 +121,8 @@
## Contributors
-| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/DeepScience.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
-|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:|
+| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
+|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
## Acknowledge
@@ -82,13 +131,17 @@
3. We referred [Wenet](https://github.com/wenet-e2e/wenet) for building dataloader for large scale data training.
4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
-6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
## License
This project is licensed under the [The MIT License](https://opensource.org/licenses/MIT). FunASR also contains various third-party components and some code modified from other repos under other open source licenses.
The use of pretraining model is subject to [model licencs](./MODEL_LICENSE)
+
+## Stargazers over time
+
+[](https://starchart.cc/alibaba-damo-academy/FunASR)
+
## Citations
``` bibtex
diff --git a/docs/benchmark/benchmark_libtorch.md b/docs/benchmark/benchmark_libtorch.md
index f1cd73c..04ba682 120000
--- a/docs/benchmark/benchmark_libtorch.md
+++ b/docs/benchmark/benchmark_libtorch.md
@@ -1 +1 @@
-../../funasr/runtime/python/benchmark_libtorch.md
\ No newline at end of file
+../../funasr/runtime/docs/benchmark_libtorch.md
\ No newline at end of file
diff --git a/docs/benchmark/benchmark_onnx.md b/docs/benchmark/benchmark_onnx.md
index 14e2fbe..c199094 120000
--- a/docs/benchmark/benchmark_onnx.md
+++ b/docs/benchmark/benchmark_onnx.md
@@ -1 +1 @@
-../../funasr/runtime/python/benchmark_onnx.md
\ No newline at end of file
+../../funasr/runtime/docs/benchmark_onnx.md
\ No newline at end of file
diff --git a/docs/benchmark/benchmark_onnx_cpp.md b/docs/benchmark/benchmark_onnx_cpp.md
index 3754852..c4ab108 120000
--- a/docs/benchmark/benchmark_onnx_cpp.md
+++ b/docs/benchmark/benchmark_onnx_cpp.md
@@ -1 +1 @@
-../../funasr/runtime/python/benchmark_onnx_cpp.md
\ No newline at end of file
+../../funasr/runtime/docs/benchmark_onnx_cpp.md
\ No newline at end of file
diff --git a/docs/benchmark/benchmark_pipeline_cer.md b/docs/benchmark/benchmark_pipeline_cer.md
index 9f42c95..d978f3e 100644
--- a/docs/benchmark/benchmark_pipeline_cer.md
+++ b/docs/benchmark/benchmark_pipeline_cer.md
@@ -1,4 +1,4 @@
-# Benchmark (ModeScope Pipeline)
+# Leaderboard IO
## Configuration
@@ -45,156 +45,156 @@
### Chinese Dataset
-<table>
+<table border="1">
<tr align="center">
- <td>Model</td>
- <td>Offline/Online</td>
- <td colspan="2">Aishell1</td>
- <td colspan="4">Aishell2</td>
- <td colspan="3">WenetSpeech</td>
+ <td style="border: 1px solid">Model</td>
+ <td style="border: 1px solid">Offline/Online</td>
+ <td colspan="2" style="border: 1px solid">Aishell1</td>
+ <td colspan="4" style="border: 1px solid">Aishell2</td>
+ <td colspan="3" style="border: 1px solid">WenetSpeech</td>
</tr>
<tr align="center">
- <td></td>
- <td></td>
- <td>dev</td>
- <td>test</td>
- <td>dev_ios</td>
- <td>test_ios</td>
- <td>test_android</td>
- <td>test_mic</td>
- <td>dev</td>
- <td>test_meeting</td>
- <td>test_net</td>
+ <td style="border: 1px solid"></td>
+ <td style="border: 1px solid"></td>
+ <td style="border: 1px solid">dev</td>
+ <td style="border: 1px solid">test</td>
+ <td style="border: 1px solid">dev_ios</td>
+ <td style="border: 1px solid">test_ios</td>
+ <td style="border: 1px solid">test_android</td>
+ <td style="border: 1px solid">test_mic</td>
+ <td style="border: 1px solid">dev</td>
+ <td style="border: 1px solid">test_meeting</td>
+ <td style="border: 1px solid">test_net</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
- <td>Offline</td>
- <td>1.76</td>
- <td>1.94</td>
- <td>2.79</td>
- <td>2.84</td>
- <td>3.08</td>
- <td>3.03</td>
- <td>3.43</td>
- <td>7.01</td>
- <td>6.66</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.76</td>
+ <td style="border: 1px solid">1.94</td>
+ <td style="border: 1px solid">2.79</td>
+ <td style="border: 1px solid">2.84</td>
+ <td style="border: 1px solid">3.08</td>
+ <td style="border: 1px solid">3.03</td>
+ <td style="border: 1px solid">3.43</td>
+ <td style="border: 1px solid">7.01</td>
+ <td style="border: 1px solid">6.66</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
- <td>Offline</td>
- <td>1.80</td>
- <td>2.10</td>
- <td>2.78</td>
- <td>2.87</td>
- <td>3.12</td>
- <td>3.11</td>
- <td>3.44</td>
- <td>13.28</td>
- <td>7.08</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.80</td>
+ <td style="border: 1px solid">2.10</td>
+ <td style="border: 1px solid">2.78</td>
+ <td style="border: 1px solid">2.87</td>
+ <td style="border: 1px solid">3.12</td>
+ <td style="border: 1px solid">3.11</td>
+ <td style="border: 1px solid">3.44</td>
+ <td style="border: 1px solid">13.28</td>
+ <td style="border: 1px solid">7.08</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
- <td>Offline</td>
- <td>1.76</td>
- <td>2.02</td>
- <td>2.73</td>
- <td>2.85</td>
- <td>2.98</td>
- <td>2.95</td>
- <td>3.42</td>
- <td>7.16</td>
- <td>6.72</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">1.76</td>
+ <td style="border: 1px solid">2.02</td>
+ <td style="border: 1px solid">2.73</td>
+ <td style="border: 1px solid">2.85</td>
+ <td style="border: 1px solid">2.98</td>
+ <td style="border: 1px solid">2.95</td>
+ <td style="border: 1px solid">3.42</td>
+ <td style="border: 1px solid">7.16</td>
+ <td style="border: 1px solid">6.72</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
- <td>Offline</td>
- <td>3.24</td>
- <td>3.69</td>
- <td>4.58</td>
- <td>4.63</td>
- <td>4.83</td>
- <td>4.71</td>
- <td>4.19</td>
- <td>8.32</td>
- <td>9.19</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">3.24</td>
+ <td style="border: 1px solid">3.69</td>
+ <td style="border: 1px solid">4.58</td>
+ <td style="border: 1px solid">4.63</td>
+ <td style="border: 1px solid">4.83</td>
+ <td style="border: 1px solid">4.71</td>
+ <td style="border: 1px solid">4.19</td>
+ <td style="border: 1px solid">8.32</td>
+ <td style="border: 1px solid">9.19</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
- <td>Online</td>
- <td>3.34</td>
- <td>3.99</td>
- <td>4.62</td>
- <td>4.52</td>
- <td>4.77</td>
- <td>4.73</td>
- <td>4.51</td>
- <td>10.63</td>
- <td>9.70</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td>
+ <td style="border: 1px solid">Online</td>
+ <td style="border: 1px solid">3.34</td>
+ <td style="border: 1px solid">3.99</td>
+ <td style="border: 1px solid">4.62</td>
+ <td style="border: 1px solid">4.52</td>
+ <td style="border: 1px solid">4.77</td>
+ <td style="border: 1px solid">4.73</td>
+ <td style="border: 1px solid">4.51</td>
+ <td style="border: 1px solid">10.63</td>
+ <td style="border: 1px solid">9.70</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
- <td>Offline</td>
- <td>2.93</td>
- <td>3.48</td>
- <td>3.95</td>
- <td>3.87</td>
- <td>4.11</td>
- <td>4.11</td>
- <td>4.16</td>
- <td>10.09</td>
- <td>8.69</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">2.93</td>
+ <td style="border: 1px solid">3.48</td>
+ <td style="border: 1px solid">3.95</td>
+ <td style="border: 1px solid">3.87</td>
+ <td style="border: 1px solid">4.11</td>
+ <td style="border: 1px solid">4.11</td>
+ <td style="border: 1px solid">4.16</td>
+ <td style="border: 1px solid">10.09</td>
+ <td style="border: 1px solid">8.69</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
- <td>Offline</td>
- <td>4.88</td>
- <td>5.43</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">4.88</td>
+ <td style="border: 1px solid">5.43</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
- <td>Offline</td>
- <td>6.14</td>
- <td>7.01</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">6.14</td>
+ <td style="border: 1px solid">7.01</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
- <td>Offline</td>
- <td>-</td>
- <td>-</td>
- <td>5.82</td>
- <td>6.30</td>
- <td>6.60</td>
- <td>5.83</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">5.82</td>
+ <td style="border: 1px solid">6.30</td>
+ <td style="border: 1px solid">6.60</td>
+ <td style="border: 1px solid">5.83</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
<tr align="center">
- <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
- <td>Offline</td>
- <td>-</td>
- <td>-</td>
- <td>4.95</td>
- <td>5.45</td>
- <td>5.59</td>
- <td>5.83</td>
- <td>-</td>
- <td>-</td>
- <td>-</td>
+ <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td>
+ <td style="border: 1px solid">Offline</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">4.95</td>
+ <td style="border: 1px solid">5.45</td>
+ <td style="border: 1px solid">5.59</td>
+ <td style="border: 1px solid">5.83</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
+ <td style="border: 1px solid">-</td>
</tr>
</table>
diff --git a/docs/images/dingding.jpg b/docs/images/dingding.jpg
index 6ac3ab8..9c9166c 100644
--- a/docs/images/dingding.jpg
+++ b/docs/images/dingding.jpg
Binary files differ
diff --git a/docs/index.rst b/docs/index.rst
index cb98f35..87e3a25 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -77,11 +77,12 @@
.. toctree::
:maxdepth: 1
- :caption: Benchmark and Leadboard
+ :caption: Benchmark and Leaderboard
./benchmark/benchmark_onnx.md
./benchmark/benchmark_onnx_cpp.md
./benchmark/benchmark_libtorch.md
+ ./benchmark/benchmark_pipeline_cer.md
.. toctree::
diff --git a/docs/installation/installation.md b/docs/installation/installation.md
index d020b51..f81ae83 100755
--- a/docs/installation/installation.md
+++ b/docs/installation/installation.md
@@ -32,7 +32,7 @@
### Install Pytorch (version >= 1.11.0):
```sh
-pip install torch torchaudio
+pip3 install torch torchaudio
```
If there exists CUDAs in your environments, you should install the pytorch with the version matching the CUDA. The matching list could be found in [docs](https://pytorch.org/get-started/previous-versions/).
### Install funasr
@@ -40,27 +40,27 @@
#### Install from pip
```shell
-pip install -U funasr
+pip3 install -U funasr
# For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
#### Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install -e ./
+pip3 install -e ./
# For the users in China, you could install with the command:
-# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Install modelscope (Optional)
If you want to use the pretrained models in ModelScope, you should install the modelscope:
```shell
-pip install -U modelscope
+pip3 install -U modelscope
# For the users in China, you could install with the command:
-# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### FQA
diff --git a/docs/m2met2/Challenge_result.md b/docs/m2met2/Challenge_result.md
new file mode 100644
index 0000000..52bbedd
--- /dev/null
+++ b/docs/m2met2/Challenge_result.md
@@ -0,0 +1,14 @@
+# Challenge Result
+The following table shows the final results of the competition, where Sub-track1 represents the sub-track under fixed training condition and Sub-track 2 represents the sub-track under the open training condition. All result in this table is cp-CER (%). The rankings in the table are the combined rankings of the two sub-tracks as all teams' submissions met the requirements of the sub-track under fixed training condition.
+| Rank | Team Name | Sub-track1 | Sub-track2 | paper |
+|------|----------------------|------------|------------|------------------------|
+| 1 | Ximalaya Speech Team | 11.27 | 11.27 | |
+| 2 | 灏忛┈杈� | 18.64 | 18.64 | |
+| 3 | AIzyzx | 22.83 | 22.83 | |
+| 4 | AsrSpeeder | / | 23.51 | |
+| 5 | zyxlhz | 24.82 | 24.82 | |
+| 6 | CMCAI | 26.11 | / | |
+| 7 | Volcspeech | 34.21 | 34.21 | |
+| 8 | 閴村線鐭ユ潵 | 40.14 | 40.14 | |
+| 9 | baseline | 41.55 | 41.55 | |
+| 10 | DAICT | 41.64 | | |
diff --git a/docs/m2met2/_build/doctrees/Challenge_result.doctree b/docs/m2met2/_build/doctrees/Challenge_result.doctree
new file mode 100644
index 0000000..03c4e24
--- /dev/null
+++ b/docs/m2met2/_build/doctrees/Challenge_result.doctree
Binary files differ
diff --git a/docs/m2met2/_build/doctrees/environment.pickle b/docs/m2met2/_build/doctrees/environment.pickle
index 3002d02..44a2ec5 100644
--- a/docs/m2met2/_build/doctrees/environment.pickle
+++ b/docs/m2met2/_build/doctrees/environment.pickle
Binary files differ
diff --git a/docs/m2met2/_build/doctrees/index.doctree b/docs/m2met2/_build/doctrees/index.doctree
index 9469f3c..a6c54f4 100644
--- a/docs/m2met2/_build/doctrees/index.doctree
+++ b/docs/m2met2/_build/doctrees/index.doctree
Binary files differ
diff --git a/docs/m2met2/_build/html/Baseline.html b/docs/m2met2/_build/html/Baseline.html
index c578602..4426a41 100644
--- a/docs/m2met2/_build/html/Baseline.html
+++ b/docs/m2met2/_build/html/Baseline.html
@@ -100,6 +100,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/Challenge_result.html b/docs/m2met2/_build/html/Challenge_result.html
new file mode 100644
index 0000000..226e671
--- /dev/null
+++ b/docs/m2met2/_build/html/Challenge_result.html
@@ -0,0 +1,247 @@
+
+<!DOCTYPE html>
+
+<html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
+
+
+ <!-- Licensed under the Apache 2.0 License -->
+ <link rel="stylesheet" type="text/css" href="_static/fonts/open-sans/stylesheet.css" />
+ <!-- Licensed under the SIL Open Font License -->
+ <link rel="stylesheet" type="text/css" href="_static/fonts/source-serif-pro/source-serif-pro.css" />
+ <link rel="stylesheet" type="text/css" href="_static/css/bootstrap.min.css" />
+ <link rel="stylesheet" type="text/css" href="_static/css/bootstrap-theme.min.css" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+
+ <title>Challenge Result — MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</title>
+ <link rel="stylesheet" type="text/css" href="_static/pygments.css" />
+ <link rel="stylesheet" type="text/css" href="_static/guzzle.css" />
+ <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
+ <script src="_static/jquery.js"></script>
+ <script src="_static/underscore.js"></script>
+ <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
+ <script src="_static/doctools.js"></script>
+ <script src="_static/sphinx_highlight.js"></script>
+ <link rel="index" title="Index" href="genindex.html" />
+ <link rel="search" title="Search" href="search.html" />
+ <link rel="next" title="Organizers" href="Organizers.html" />
+ <link rel="prev" title="Rules" href="Rules.html" />
+
+
+
+ </head><body>
+ <div class="related" role="navigation" aria-label="related navigation">
+ <h3>Navigation</h3>
+ <ul>
+ <li class="right" style="margin-right: 10px">
+ <a href="genindex.html" title="General Index"
+ accesskey="I">index</a></li>
+ <li class="right" >
+ <a href="Organizers.html" title="Organizers"
+ accesskey="N">next</a> |</li>
+ <li class="right" >
+ <a href="Rules.html" title="Rules"
+ accesskey="P">previous</a> |</li>
+ <li class="nav-item nav-item-0"><a href="index.html">MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</a> »</li>
+ <li class="nav-item nav-item-this"><a href="">Challenge Result</a></li>
+ </ul>
+ </div>
+ <div class="container-wrapper">
+
+ <div id="mobile-toggle">
+ <a href="#"><span class="glyphicon glyphicon-align-justify" aria-hidden="true"></span></a>
+ </div>
+ <div id="left-column">
+ <div class="sphinxsidebar"><a href="
+ index.html" class="text-logo">MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</a>
+<div class="sidebar-block">
+ <div class="sidebar-wrapper">
+ <div id="main-search">
+ <form class="form-inline" action="search.html" method="GET" role="form">
+ <div class="input-group">
+ <input name="q" type="text" class="form-control" placeholder="Search...">
+ </div>
+ <input type="hidden" name="check_keywords" value="yes" />
+ <input type="hidden" name="area" value="default" />
+ </form>
+ </div>
+ </div>
+</div>
+<div class="sidebar-block">
+ <div class="sidebar-toc">
+
+
+ <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
+<ul class="current">
+<li class="toctree-l1"><a class="reference internal" href="Introduction.html">Introduction</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="Introduction.html#call-for-participation">Call for participation</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Introduction.html#timeline-aoe-time">Timeline(AOE Time)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Introduction.html#guidelines">Guidelines</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="Dataset.html">Datasets</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="Dataset.html#overview-of-training-data">Overview of training data</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Dataset.html#detail-of-alimeeting-corpus">Detail of AliMeeting corpus</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Dataset.html#get-the-data">Get the data</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="Track_setting_and_evaluation.html">Track & Evaluation</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="Track_setting_and_evaluation.html#speaker-attributed-asr">Speaker-Attributed ASR</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Track_setting_and_evaluation.html#evaluation-metric">Evaluation metric</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Track_setting_and_evaluation.html#sub-track-arrangement">Sub-track arrangement</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="Baseline.html">Baseline</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="Baseline.html#overview">Overview</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Baseline.html#quick-start">Quick start</a></li>
+<li class="toctree-l2"><a class="reference internal" href="Baseline.html#baseline-results">Baseline results</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1 current"><a class="current reference internal" href="#">Challenge Result</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
+</ul>
+
+
+ </div>
+</div>
+
+ </div>
+ </div>
+ <div id="right-column">
+
+ <div role="navigation" aria-label="breadcrumbs navigation">
+ <ol class="breadcrumb">
+ <li><a href="index.html">Docs</a></li>
+
+ <li>Challenge Result</li>
+ </ol>
+ </div>
+
+ <div class="document clearer body">
+
+ <section id="challenge-result">
+<h1>Challenge Result<a class="headerlink" href="#challenge-result" title="Permalink to this heading">露</a></h1>
+<p>The following table shows the final results of the competition, where Sub-track1 represents the sub-track under fixed training condition and Sub-track 2 represents the sub-track under the open training condition. All result in this table is cp-CER (%). The rankings in the table are the combined rankings of the two sub-tracks as all teams鈥� submissions met the requirements of the sub-track under fixed training condition.</p>
+<table class="docutils align-default">
+<thead>
+<tr class="row-odd"><th class="head"><p>Rank 聽 聽</p></th>
+<th class="head"><p>Team Name 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽</p></th>
+<th class="head"><p>Sub-track1 聽 聽</p></th>
+<th class="head"><p>Sub-track2 聽 聽</p></th>
+<th class="head"><p>paper</p></th>
+</tr>
+</thead>
+<tbody>
+<tr class="row-even"><td><p>1</p></td>
+<td><p>Ximalaya Speech Team</p></td>
+<td><p>11.27</p></td>
+<td><p>11.27</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>2</p></td>
+<td><p>灏忛┈杈�</p></td>
+<td><p>18.64</p></td>
+<td><p>18.64</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>3</p></td>
+<td><p>AIzyzx</p></td>
+<td><p>22.83</p></td>
+<td><p>22.83</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>4</p></td>
+<td><p>AsrSpeeder</p></td>
+<td><p>/</p></td>
+<td><p>23.51</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>5</p></td>
+<td><p>zyxlhz</p></td>
+<td><p>24.82</p></td>
+<td><p>24.82</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>6</p></td>
+<td><p>CMCAI</p></td>
+<td><p>26.11</p></td>
+<td><p>/</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>7</p></td>
+<td><p>Volcspeech</p></td>
+<td><p>34.21</p></td>
+<td><p>34.21</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>8</p></td>
+<td><p>閴村線鐭ユ潵</p></td>
+<td><p>40.14</p></td>
+<td><p>40.14</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>9</p></td>
+<td><p>baseline</p></td>
+<td><p>41.55</p></td>
+<td><p>41.55</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>10</p></td>
+<td><p>DAICT</p></td>
+<td><p>41.64</p></td>
+<td><p></p></td>
+<td><p></p></td>
+</tr>
+</tbody>
+</table>
+</section>
+
+
+ </div>
+
+ <div class="footer-relations">
+
+ <div class="pull-left">
+ <a class="btn btn-default" href="Rules.html" title="previous chapter (use the left arrow)">Rules</a>
+ </div>
+
+ <div class="pull-right">
+ <a class="btn btn-default" href="Organizers.html" title="next chapter (use the right arrow)">Organizers</a>
+ </div>
+ </div>
+ <div class="clearer"></div>
+
+ </div>
+ <div class="clearfix"></div>
+ </div>
+ <div class="related" role="navigation" aria-label="related navigation">
+ <h3>Navigation</h3>
+ <ul>
+ <li class="right" style="margin-right: 10px">
+ <a href="genindex.html" title="General Index"
+ >index</a></li>
+ <li class="right" >
+ <a href="Organizers.html" title="Organizers"
+ >next</a> |</li>
+ <li class="right" >
+ <a href="Rules.html" title="Rules"
+ >previous</a> |</li>
+ <li class="nav-item nav-item-0"><a href="index.html">MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</a> »</li>
+ <li class="nav-item nav-item-this"><a href="">Challenge Result</a></li>
+ </ul>
+ </div>
+<script type="text/javascript">
+ $("#mobile-toggle a").click(function () {
+ $("#left-column").toggle();
+ });
+</script>
+<script type="text/javascript" src="_static/js/bootstrap.js"></script>
+ <div class="footer">
+ © Copyright 2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University. Created using <a href="http://sphinx.pocoo.org/">Sphinx</a>.
+ </div>
+ </body>
+</html>
\ No newline at end of file
diff --git a/docs/m2met2/_build/html/Contact.html b/docs/m2met2/_build/html/Contact.html
index f268ef4..6596b3e 100644
--- a/docs/m2met2/_build/html/Contact.html
+++ b/docs/m2met2/_build/html/Contact.html
@@ -96,6 +96,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/Dataset.html b/docs/m2met2/_build/html/Dataset.html
index f6b2a04..9eb62a5 100644
--- a/docs/m2met2/_build/html/Dataset.html
+++ b/docs/m2met2/_build/html/Dataset.html
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/Introduction.html b/docs/m2met2/_build/html/Introduction.html
index 82394fc..1e541f2 100644
--- a/docs/m2met2/_build/html/Introduction.html
+++ b/docs/m2met2/_build/html/Introduction.html
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/Organizers.html b/docs/m2met2/_build/html/Organizers.html
index e500019..6d89513 100644
--- a/docs/m2met2/_build/html/Organizers.html
+++ b/docs/m2met2/_build/html/Organizers.html
@@ -27,7 +27,7 @@
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="Contact" href="Contact.html" />
- <link rel="prev" title="Rules" href="Rules.html" />
+ <link rel="prev" title="Challenge Result" href="Challenge_result.html" />
@@ -42,7 +42,7 @@
<a href="Contact.html" title="Contact"
accesskey="N">next</a> |</li>
<li class="right" >
- <a href="Rules.html" title="Rules"
+ <a href="Challenge_result.html" title="Challenge Result"
accesskey="P">previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</a> »</li>
<li class="nav-item nav-item-this"><a href="">Organizers</a></li>
@@ -100,6 +100,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
@@ -156,7 +157,7 @@
<div class="footer-relations">
<div class="pull-left">
- <a class="btn btn-default" href="Rules.html" title="previous chapter (use the left arrow)">Rules</a>
+ <a class="btn btn-default" href="Challenge_result.html" title="previous chapter (use the left arrow)">Challenge Result</a>
</div>
<div class="pull-right">
@@ -178,7 +179,7 @@
<a href="Contact.html" title="Contact"
>next</a> |</li>
<li class="right" >
- <a href="Rules.html" title="Rules"
+ <a href="Challenge_result.html" title="Challenge Result"
>previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0</a> »</li>
<li class="nav-item nav-item-this"><a href="">Organizers</a></li>
diff --git a/docs/m2met2/_build/html/Rules.html b/docs/m2met2/_build/html/Rules.html
index 01f79cb..29b58e6 100644
--- a/docs/m2met2/_build/html/Rules.html
+++ b/docs/m2met2/_build/html/Rules.html
@@ -26,7 +26,7 @@
<script src="_static/sphinx_highlight.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
- <link rel="next" title="Organizers" href="Organizers.html" />
+ <link rel="next" title="Challenge Result" href="Challenge_result.html" />
<link rel="prev" title="Baseline" href="Baseline.html" />
@@ -39,7 +39,7 @@
<a href="genindex.html" title="General Index"
accesskey="I">index</a></li>
<li class="right" >
- <a href="Organizers.html" title="Organizers"
+ <a href="Challenge_result.html" title="Challenge Result"
accesskey="N">next</a> |</li>
<li class="right" >
<a href="Baseline.html" title="Baseline"
@@ -100,6 +100,7 @@
</ul>
</li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
@@ -145,7 +146,7 @@
</div>
<div class="pull-right">
- <a class="btn btn-default" href="Organizers.html" title="next chapter (use the right arrow)">Organizers</a>
+ <a class="btn btn-default" href="Challenge_result.html" title="next chapter (use the right arrow)">Challenge Result</a>
</div>
</div>
<div class="clearer"></div>
@@ -160,7 +161,7 @@
<a href="genindex.html" title="General Index"
>index</a></li>
<li class="right" >
- <a href="Organizers.html" title="Organizers"
+ <a href="Challenge_result.html" title="Challenge Result"
>next</a> |</li>
<li class="right" >
<a href="Baseline.html" title="Baseline"
diff --git a/docs/m2met2/_build/html/Track_setting_and_evaluation.html b/docs/m2met2/_build/html/Track_setting_and_evaluation.html
index 1cd72d9..49af652 100644
--- a/docs/m2met2/_build/html/Track_setting_and_evaluation.html
+++ b/docs/m2met2/_build/html/Track_setting_and_evaluation.html
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/_sources/Challenge_result.md.txt b/docs/m2met2/_build/html/_sources/Challenge_result.md.txt
new file mode 100644
index 0000000..52bbedd
--- /dev/null
+++ b/docs/m2met2/_build/html/_sources/Challenge_result.md.txt
@@ -0,0 +1,14 @@
+# Challenge Result
+The following table shows the final results of the competition, where Sub-track1 represents the sub-track under fixed training condition and Sub-track 2 represents the sub-track under the open training condition. All result in this table is cp-CER (%). The rankings in the table are the combined rankings of the two sub-tracks as all teams' submissions met the requirements of the sub-track under fixed training condition.
+| Rank | Team Name | Sub-track1 | Sub-track2 | paper |
+|------|----------------------|------------|------------|------------------------|
+| 1 | Ximalaya Speech Team | 11.27 | 11.27 | |
+| 2 | 灏忛┈杈� | 18.64 | 18.64 | |
+| 3 | AIzyzx | 22.83 | 22.83 | |
+| 4 | AsrSpeeder | / | 23.51 | |
+| 5 | zyxlhz | 24.82 | 24.82 | |
+| 6 | CMCAI | 26.11 | / | |
+| 7 | Volcspeech | 34.21 | 34.21 | |
+| 8 | 閴村線鐭ユ潵 | 40.14 | 40.14 | |
+| 9 | baseline | 41.55 | 41.55 | |
+| 10 | DAICT | 41.64 | | |
diff --git a/docs/m2met2/_build/html/_sources/index.rst.txt b/docs/m2met2/_build/html/_sources/index.rst.txt
index e0e7562..672cb91 100644
--- a/docs/m2met2/_build/html/_sources/index.rst.txt
+++ b/docs/m2met2/_build/html/_sources/index.rst.txt
@@ -18,5 +18,6 @@
./Track_setting_and_evaluation
./Baseline
./Rules
+ ./Challenge_result
./Organizers
./Contact
diff --git a/docs/m2met2/_build/html/genindex.html b/docs/m2met2/_build/html/genindex.html
index b331f6f..57e01f5 100644
--- a/docs/m2met2/_build/html/genindex.html
+++ b/docs/m2met2/_build/html/genindex.html
@@ -91,6 +91,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/index.html b/docs/m2met2/_build/html/index.html
index dd2a9cc..ef4d627 100644
--- a/docs/m2met2/_build/html/index.html
+++ b/docs/m2met2/_build/html/index.html
@@ -96,6 +96,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
@@ -131,6 +132,7 @@
<li class="toctree-l1"><a class="reference internal" href="Track_setting_and_evaluation.html">Track & Evaluation</a></li>
<li class="toctree-l1"><a class="reference internal" href="Baseline.html">Baseline</a></li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/objects.inv b/docs/m2met2/_build/html/objects.inv
index d8a5ea5..549117e 100644
--- a/docs/m2met2/_build/html/objects.inv
+++ b/docs/m2met2/_build/html/objects.inv
Binary files differ
diff --git a/docs/m2met2/_build/html/search.html b/docs/m2met2/_build/html/search.html
index f91b51a..b6ad6be 100644
--- a/docs/m2met2/_build/html/search.html
+++ b/docs/m2met2/_build/html/search.html
@@ -84,6 +84,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="Rules.html">Rules</a></li>
+<li class="toctree-l1"><a class="reference internal" href="Challenge_result.html">Challenge Result</a></li>
<li class="toctree-l1"><a class="reference internal" href="Organizers.html">Organizers</a></li>
<li class="toctree-l1"><a class="reference internal" href="Contact.html">Contact</a></li>
</ul>
diff --git a/docs/m2met2/_build/html/searchindex.js b/docs/m2met2/_build/html/searchindex.js
index 3387db5..5e65c9d 100644
--- a/docs/m2met2/_build/html/searchindex.js
+++ b/docs/m2met2/_build/html/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["Baseline", "Contact", "Dataset", "Introduction", "Organizers", "Rules", "Track_setting_and_evaluation", "index"], "filenames": ["Baseline.md", "Contact.md", "Dataset.md", "Introduction.md", "Organizers.md", "Rules.md", "Track_setting_and_evaluation.md", "index.rst"], "titles": ["Baseline", "Contact", "Datasets", "Introduction", "Organizers", "Rules", "Track & Evaluation", "ASRU 2023 MULTI-CHANNEL MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0 (M2MeT2.0)"], "terms": {"we": [0, 2, 3, 7], "releas": [0, 2, 3, 6], "an": [0, 2, 3, 6], "e2": 0, "sa": 0, "asr": [0, 3, 7], "conduct": [0, 2], "funasr": 0, "time": [0, 6], "accord": [0, 3], "timelin": [0, 2], "The": [0, 2, 3, 5, 6], "model": [0, 2, 3, 5, 6], "architectur": 0, "i": [0, 2, 3, 5], "shown": [0, 2], "figur": [0, 6], "3": [0, 2, 3], "speakerencod": 0, "initi": 0, "pre": [0, 6], "train": [0, 3, 5, 7], "speaker": [0, 2, 3, 7], "verif": 0, "from": [0, 2, 3, 5, 6], "modelscop": [0, 6], "thi": [0, 3, 5, 6], "also": [0, 2, 3, 6], "us": [0, 2, 5, 6], "extract": 0, "embed": 0, "profil": 0, "To": [0, 2, 3, 7], "run": 0, "first": 0, "you": [0, 1], "need": 0, "instal": 0, "There": [0, 2], "ar": [0, 2, 3, 5, 6, 7], "two": [0, 3, 5, 7], "startup": 0, "script": [0, 2], "sh": 0, "evalu": [0, 2, 3, 7], "old": 0, "eval": [0, 2, 5, 6], "test": [0, 2, 3, 5, 6], "set": [0, 2, 3, 5, 6], "run_m2met_2023_inf": 0, "infer": 0, "new": [0, 2, 3, 6], "multi": [0, 3, 6], "channel": [0, 3], "parti": [0, 3, 6], "meet": [0, 2, 3, 6], "transcript": [0, 2, 3, 5, 6], "2": [0, 2, 6], "0": [0, 1, 2, 3], "m2met2": [0, 1, 3], "challeng": [0, 1, 3, 5, 6], "befor": 0, "must": [0, 3, 5, 6], "manual": [0, 6], "download": [0, 2], "unpack": 0, "alimeet": [0, 1, 6], "corpu": [0, 6], "place": [0, 2], "dataset": [0, 3, 5, 6, 7], "directori": 0, "eval_ali_far": 0, "eval_ali_near": 0, "test_ali_far": 0, "test_ali_near": 0, "train_ali_far": 0, "train_ali_near": 0, "test_2023_ali_far": 0, "after": 0, "which": [0, 2, 3, 6], "contain": [0, 2, 6], "onli": [0, 2, 5, 6], "raw": 0, "audio": [0, 2, 3, 6], "Then": 0, "put": 0, "given": 0, "wav": 0, "scp": 0, "wav_raw": 0, "segment": [0, 2, 6], "utt2spk": 0, "spk2utt": 0, "data": [0, 3, 5, 6], "For": [0, 2], "more": [0, 2], "detail": [0, 3, 6], "can": [0, 2, 3, 5, 6], "see": 0, "here": 0, "system": [0, 3, 5, 6, 7], "tabl": [0, 2], "adopt": 0, "oracl": [0, 6], "dure": [0, 2, 6], "howev": [0, 3, 6], "due": [0, 3], "lack": 0, "label": [0, 5, 6], "provid": [0, 2, 6, 7], "addit": [0, 6], "spectral": 0, "cluster": 0, "meanwhil": 0, "show": 0, "impact": 0, "accuraci": [0, 6], "If": [1, 5, 6], "have": [1, 3], "ani": [1, 5, 6], "question": 1, "about": [1, 3], "pleas": 1, "u": [1, 2], "email": [1, 3, 4], "m2met": [1, 3, 6, 7], "gmail": 1, "com": [1, 4], "wechat": [1, 3], "group": [1, 2, 3], "In": [2, 3, 5], "fix": [2, 3, 7], "condit": [2, 3, 7], "restrict": 2, "three": [2, 3, 6], "publicli": [2, 6], "avail": [2, 6], "corpora": 2, "name": 2, "aishel": [2, 4, 6], "4": [2, 6], "cn": [2, 4, 6], "celeb": [2, 6], "perform": [2, 3], "call": 2, "2023": [2, 3, 5, 6], "score": [2, 6], "rank": [2, 3, 6], "describ": 2, "118": 2, "75": 2, "hour": [2, 3, 6], "speech": [2, 3, 6, 7], "total": [2, 6], "divid": [2, 6], "104": 2, "10": [2, 3, 6], "specif": [2, 6], "212": 2, "8": 2, "20": [2, 3], "session": [2, 3, 6, 7], "respect": 2, "each": [2, 3, 6], "consist": [2, 6], "15": 2, "30": 2, "minut": 2, "discuss": 2, "particip": [2, 5, 6], "number": [2, 3, 6], "456": 2, "25": 2, "60": 2, "balanc": 2, "gender": 2, "coverag": 2, "collect": 2, "13": 2, "venu": 2, "categor": 2, "type": 2, "small": 2, "medium": 2, "larg": [2, 3], "room": [2, 3], "size": 2, "rang": 2, "m": 2, "55": 2, "differ": [2, 3, 6], "give": 2, "varieti": 2, "acoust": [2, 3, 6], "properti": 2, "layout": 2, "paramet": [2, 5], "togeth": 2, "wall": 2, "materi": 2, "cover": 2, "cement": 2, "glass": 2, "etc": 2, "other": 2, "furnish": 2, "includ": [2, 3, 5, 6], "sofa": 2, "tv": 2, "blackboard": 2, "fan": 2, "air": 2, "condition": 2, "plant": 2, "record": [2, 6], "sit": 2, "around": 2, "microphon": [2, 3], "arrai": [2, 3], "natur": 2, "convers": 2, "distanc": 2, "5": 2, "all": [2, 3, 5, 6], "nativ": 2, "chines": 2, "speak": [2, 3], "mandarin": [2, 3], "without": 2, "strong": 2, "accent": 2, "variou": [2, 3], "kind": 2, "indoor": 2, "nois": [2, 3, 5], "limit": [2, 3, 5], "click": 2, "keyboard": 2, "door": 2, "open": [2, 3, 7], "close": [2, 3], "bubbl": 2, "made": [2, 3], "both": [2, 6], "requir": [2, 3, 6], "remain": [2, 3], "same": [2, 5], "posit": 2, "overlap": [2, 3], "between": [2, 6], "exampl": 2, "fig": 2, "1": 2, "within": [2, 3], "one": [2, 5], "ensur": 2, "ratio": 2, "select": [2, 3, 5, 6], "topic": 2, "medic": 2, "treatment": 2, "educ": 2, "busi": 2, "organ": [2, 3, 5, 6, 7], "manag": 2, "industri": [2, 3], "product": 2, "daili": 2, "routin": 2, "averag": 2, "42": 2, "27": 2, "34": 2, "76": 2, "A": [2, 4], "distribut": 2, "were": 2, "ident": [2, 6], "compris": [2, 3, 7], "therebi": 2, "share": 2, "similar": 2, "configur": 2, "field": [2, 3, 6], "signal": [2, 3], "headset": 2, "": [2, 6], "own": 2, "transcrib": [2, 3, 6], "It": [2, 6], "worth": [2, 6], "note": [2, 6], "far": [2, 3], "synchron": 2, "common": 2, "prepar": 2, "textgrid": 2, "format": 2, "inform": [2, 3], "durat": 2, "id": 2, "timestamp": [2, 6], "mention": 2, "abov": 2, "openslr": 2, "via": 2, "follow": [2, 5], "link": 2, "particularli": 2, "baselin": [2, 3, 7], "conveni": 2, "automat": [3, 7], "recognit": [3, 7], "diariz": 3, "signific": 3, "stride": 3, "recent": 3, "year": 3, "result": 3, "surg": 3, "technologi": 3, "applic": 3, "across": 3, "domain": 3, "present": 3, "uniqu": [3, 6], "complex": [3, 5], "divers": 3, "style": 3, "variabl": 3, "confer": 3, "environment": 3, "reverber": [3, 5], "over": 3, "sever": 3, "been": 3, "advanc": [3, 7], "develop": [3, 6], "rich": 3, "comput": [3, 5], "hear": 3, "multisourc": 3, "environ": 3, "chime": 3, "latest": 3, "iter": 3, "ha": 3, "particular": 3, "focu": 3, "distant": 3, "gener": 3, "topologi": 3, "scenario": 3, "while": 3, "progress": 3, "english": 3, "languag": [3, 5], "barrier": 3, "achiev": 3, "compar": 3, "non": 3, "multimod": 3, "base": 3, "process": [3, 6], "misp": 3, "instrument": 3, "seek": 3, "address": 3, "problem": 3, "visual": 3, "everydai": 3, "home": 3, "focus": 3, "tackl": 3, "issu": 3, "offlin": 3, "icassp2022": 3, "main": 3, "task": [3, 6, 7], "former": 3, "involv": [3, 6], "identifi": 3, "who": 3, "spoke": 3, "when": 3, "latter": 3, "aim": 3, "multipl": [3, 6], "simultan": 3, "pose": [3, 6], "technic": 3, "difficulti": 3, "interfer": 3, "build": [3, 6, 7], "success": [3, 7], "previou": 3, "excit": 3, "propos": [3, 7], "asru": 3, "special": [3, 5, 7], "origin": [3, 5], "metric": [3, 7], "wa": [3, 6], "independ": 3, "meant": 3, "could": 3, "determin": 3, "correspond": [3, 5], "further": 3, "current": [3, 7], "talker": [3, 7], "toward": 3, "practic": 3, "attribut": [3, 7], "sub": [3, 5, 7], "track": [3, 5, 7], "what": 3, "facilit": [3, 7], "reproduc": [3, 7], "research": [3, 4, 7], "offer": 3, "comprehens": [3, 7], "overview": [3, 7], "rule": [3, 7], "furthermor": 3, "carefulli": 3, "curat": 3, "approxim": [3, 6], "design": 3, "enabl": 3, "valid": 3, "state": [3, 6, 7], "art": [3, 7], "area": 3, "april": 3, "29": 3, "registr": 3, "mai": 3, "11": 3, "22": 3, "deadlin": 3, "date": 3, "join": 3, "june": 3, "16": 3, "leaderboard": 3, "final": [3, 5, 6], "submiss": 3, "leaderboar": 3, "26": 3, "juli": 3, "paper": [3, 6], "decemb": 3, "12": 3, "workshop": 3, "interest": 3, "whether": 3, "academia": 3, "regist": 3, "complet": 3, "googl": 3, "form": 3, "below": 3, "welcom": 3, "keep": 3, "up": 3, "updat": 3, "work": 3, "dai": 3, "send": 3, "invit": 3, "elig": [3, 5], "team": 3, "qualifi": 3, "adher": [3, 5], "publish": 3, "page": 3, "prior": 3, "submit": 3, "descript": [3, 6], "document": 3, "approach": [3, 5], "method": 3, "top": 3, "asru2023": [3, 7], "proceed": 3, "lei": 4, "xie": 4, "professor": 4, "foundat": 4, "china": 4, "lxie": 4, "nwpu": 4, "edu": 4, "kong": 4, "aik": 4, "lee": 4, "senior": 4, "scientist": 4, "institut": 4, "infocomm": 4, "star": 4, "singapor": 4, "kongaik": 4, "ieee": 4, "org": 4, "zhiji": 4, "yan": 4, "princip": 4, "engin": 4, "alibaba": 4, "yzj": 4, "inc": 4, "shiliang": 4, "zhang": 4, "sly": 4, "zsl": 4, "yanmin": 4, "qian": 4, "shanghai": 4, "jiao": 4, "tong": 4, "univers": 4, "yanminqian": 4, "sjtu": 4, "zhuo": 4, "chen": 4, "appli": 4, "microsoft": 4, "usa": 4, "zhuc": 4, "jian": 4, "wu": 4, "wujian": 4, "hui": 4, "bu": 4, "ceo": 4, "buhui": 4, "aishelldata": 4, "should": 5, "augment": 5, "allow": [5, 6], "ad": 5, "speed": 5, "perturb": 5, "tone": 5, "chang": 5, "permit": 5, "purpos": 5, "instead": [5, 6], "util": [5, 6], "tune": 5, "violat": 5, "strictli": [5, 6], "prohibit": [5, 6], "fine": 5, "cpcer": [5, 6], "lower": 5, "judg": 5, "superior": 5, "forc": 5, "align": 5, "obtain": [5, 6], "frame": 5, "level": 5, "classif": 5, "basi": 5, "shallow": 5, "fusion": 5, "end": 5, "e": [5, 6], "g": 5, "la": 5, "rnnt": 5, "transform": [5, 6], "come": 5, "right": 5, "interpret": 5, "belong": 5, "case": 5, "circumst": 5, "coordin": 5, "assign": 6, "illustr": 6, "aishell4": 6, "constrain": 6, "sourc": 6, "addition": 6, "soon": 6, "simpl": 6, "voic": 6, "activ": 6, "detect": 6, "vad": 6, "concaten": 6, "minimum": 6, "permut": 6, "charact": 6, "error": 6, "rate": 6, "calcul": 6, "step": 6, "firstli": 6, "refer": 6, "hypothesi": 6, "chronolog": 6, "order": 6, "secondli": 6, "cer": 6, "repeat": 6, "possibl": 6, "lowest": 6, "tthe": 6, "insert": 6, "Ins": 6, "substitut": 6, "delet": 6, "del": 6, "output": 6, "text": 6, "frac": 6, "mathcal": 6, "n_": 6, "100": 6, "where": 6, "usag": 6, "third": 6, "hug": 6, "face": 6, "list": 6, "clearli": 6, "privat": 6, "simul": 6, "thei": 6, "mandatori": 6, "clear": 6, "scheme": 6, "delight": 7, "introduct": 7, "contact": 7}, "objects": {}, "objtypes": {}, "objnames": {}, "titleterms": {"baselin": 0, "overview": [0, 2], "quick": 0, "start": 0, "result": 0, "contact": 1, "dataset": 2, "train": [2, 6], "data": 2, "detail": 2, "alimeet": 2, "corpu": 2, "get": 2, "introduct": 3, "call": 3, "particip": 3, "timelin": 3, "aoe": 3, "time": 3, "guidelin": 3, "organ": 4, "rule": 5, "track": 6, "evalu": 6, "speaker": 6, "attribut": 6, "asr": 6, "metric": 6, "sub": 6, "arrang": 6, "i": 6, "fix": 6, "condit": 6, "ii": 6, "open": 6, "asru": 7, "2023": 7, "multi": 7, "channel": 7, "parti": 7, "meet": 7, "transcript": 7, "challeng": 7, "2": 7, "0": 7, "m2met2": 7, "content": 7}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 57}, "alltitles": {"Baseline": [[0, "baseline"]], "Overview": [[0, "overview"]], "Quick start": [[0, "quick-start"]], "Baseline results": [[0, "baseline-results"]], "Contact": [[1, "contact"]], "Datasets": [[2, "datasets"]], "Overview of training data": [[2, "overview-of-training-data"]], "Detail of AliMeeting corpus": [[2, "detail-of-alimeeting-corpus"]], "Get the data": [[2, "get-the-data"]], "Introduction": [[3, "introduction"]], "Call for participation": [[3, "call-for-participation"]], "Timeline(AOE Time)": [[3, "timeline-aoe-time"]], "Guidelines": [[3, "guidelines"]], "Organizers": [[4, "organizers"]], "Rules": [[5, "rules"]], "Track & Evaluation": [[6, "track-evaluation"]], "Speaker-Attributed ASR": [[6, "speaker-attributed-asr"]], "Evaluation metric": [[6, "evaluation-metric"]], "Sub-track arrangement": [[6, "sub-track-arrangement"]], "Sub-track I (Fixed Training Condition):": [[6, "sub-track-i-fixed-training-condition"]], "Sub-track II (Open Training Condition):": [[6, "sub-track-ii-open-training-condition"]], "ASRU 2023 MULTI-CHANNEL MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0 (M2MeT2.0)": [[7, "asru-2023-multi-channel-multi-party-meeting-transcription-challenge-2-0-m2met2-0"]], "Contents:": [[7, null]]}, "indexentries": {}})
\ No newline at end of file
+Search.setIndex({"docnames": ["Baseline", "Challenge_result", "Contact", "Dataset", "Introduction", "Organizers", "Rules", "Track_setting_and_evaluation", "index"], "filenames": ["Baseline.md", "Challenge_result.md", "Contact.md", "Dataset.md", "Introduction.md", "Organizers.md", "Rules.md", "Track_setting_and_evaluation.md", "index.rst"], "titles": ["Baseline", "Challenge Result", "Contact", "Datasets", "Introduction", "Organizers", "Rules", "Track & Evaluation", "ASRU 2023 MULTI-CHANNEL MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0 (M2MeT2.0)"], "terms": {"we": [0, 3, 4, 8], "releas": [0, 3, 4, 7], "an": [0, 3, 4, 7], "e2": 0, "sa": 0, "asr": [0, 4, 8], "conduct": [0, 3], "funasr": 0, "time": [0, 7], "accord": [0, 4], "timelin": [0, 3], "The": [0, 1, 3, 4, 6, 7], "model": [0, 3, 4, 6, 7], "architectur": 0, "i": [0, 1, 3, 4, 6], "shown": [0, 3], "figur": [0, 7], "3": [0, 1, 3, 4], "speakerencod": 0, "initi": 0, "pre": [0, 7], "train": [0, 1, 4, 6, 8], "speaker": [0, 3, 4, 8], "verif": 0, "from": [0, 3, 4, 6, 7], "modelscop": [0, 7], "thi": [0, 1, 4, 6, 7], "also": [0, 3, 4, 7], "us": [0, 3, 6, 7], "extract": 0, "embed": 0, "profil": 0, "To": [0, 3, 4, 8], "run": 0, "first": 0, "you": [0, 2], "need": 0, "instal": 0, "There": [0, 3], "ar": [0, 1, 3, 4, 6, 7, 8], "two": [0, 1, 4, 6, 8], "startup": 0, "script": [0, 3], "sh": 0, "evalu": [0, 3, 4, 8], "old": 0, "eval": [0, 3, 6, 7], "test": [0, 3, 4, 6, 7], "set": [0, 3, 4, 6, 7], "run_m2met_2023_inf": 0, "infer": 0, "new": [0, 3, 4, 7], "multi": [0, 4, 7], "channel": [0, 4], "parti": [0, 4, 7], "meet": [0, 3, 4, 7], "transcript": [0, 3, 4, 6, 7], "2": [0, 1, 3, 7], "0": [0, 2, 3, 4], "m2met2": [0, 2, 4], "challeng": [0, 2, 4, 6, 7], "befor": 0, "must": [0, 4, 6, 7], "manual": [0, 7], "download": [0, 3], "unpack": 0, "alimeet": [0, 2, 7], "corpu": [0, 7], "place": [0, 3], "dataset": [0, 4, 6, 7, 8], "directori": 0, "eval_ali_far": 0, "eval_ali_near": 0, "test_ali_far": 0, "test_ali_near": 0, "train_ali_far": 0, "train_ali_near": 0, "test_2023_ali_far": 0, "after": 0, "which": [0, 3, 4, 7], "contain": [0, 3, 7], "onli": [0, 3, 6, 7], "raw": 0, "audio": [0, 3, 4, 7], "Then": 0, "put": 0, "given": 0, "wav": 0, "scp": 0, "wav_raw": 0, "segment": [0, 3, 7], "utt2spk": 0, "spk2utt": 0, "data": [0, 4, 6, 7], "For": [0, 3], "more": [0, 3], "detail": [0, 4, 7], "can": [0, 3, 4, 6, 7], "see": 0, "here": 0, "system": [0, 4, 6, 7, 8], "tabl": [0, 1, 3], "adopt": 0, "oracl": [0, 7], "dure": [0, 3, 7], "howev": [0, 4, 7], "due": [0, 4], "lack": 0, "label": [0, 6, 7], "provid": [0, 3, 7, 8], "addit": [0, 7], "spectral": 0, "cluster": 0, "meanwhil": 0, "show": [0, 1], "impact": 0, "accuraci": [0, 7], "follow": [1, 3, 6], "final": [1, 4, 6, 7], "competit": 1, "where": [1, 7], "sub": [1, 4, 6, 8], "track1": 1, "repres": 1, "track": [1, 4, 6, 8], "under": 1, "fix": [1, 3, 4, 8], "condit": [1, 3, 4, 8], "open": [1, 3, 4, 8], "all": [1, 3, 4, 6, 7], "cp": 1, "cer": [1, 7], "rank": [1, 3, 4, 7], "combin": 1, "team": [1, 4], "submiss": [1, 4], "met": 1, "requir": [1, 3, 4, 7], "name": [1, 3], "track2": 1, "paper": [1, 4, 7], "1": [1, 3], "ximalaya": 1, "speech": [1, 3, 4, 7, 8], "11": [1, 4], "27": [1, 3], "\u5c0f\u9a6c\u8fbe": 1, "18": 1, "64": 1, "aizyzx": 1, "22": [1, 4], "83": 1, "4": [1, 3, 7], "asrspeed": 1, "23": 1, "51": 1, "5": [1, 3], "zyxlhz": 1, "24": 1, "82": 1, "6": 1, "cmcai": 1, "26": [1, 4], "7": 1, "volcspeech": 1, "34": [1, 3], "21": 1, "8": [1, 3], "\u9274\u5f80\u77e5\u6765": 1, "40": 1, "14": 1, "9": 1, "baselin": [1, 3, 4, 8], "41": 1, "55": [1, 3], "10": [1, 3, 4, 7], "daict": 1, "If": [2, 6, 7], "have": [2, 4], "ani": [2, 6, 7], "question": 2, "about": [2, 4], "pleas": 2, "u": [2, 3], "email": [2, 4, 5], "m2met": [2, 4, 7, 8], "gmail": 2, "com": [2, 5], "wechat": [2, 4], "group": [2, 3, 4], "In": [3, 4, 6], "restrict": 3, "three": [3, 4, 7], "publicli": [3, 7], "avail": [3, 7], "corpora": 3, "aishel": [3, 5, 7], "cn": [3, 5, 7], "celeb": [3, 7], "perform": [3, 4], "call": 3, "2023": [3, 4, 6, 7], "score": [3, 7], "describ": 3, "118": 3, "75": 3, "hour": [3, 4, 7], "total": [3, 7], "divid": [3, 7], "104": 3, "specif": [3, 7], "212": 3, "20": [3, 4], "session": [3, 4, 7, 8], "respect": 3, "each": [3, 4, 7], "consist": [3, 7], "15": 3, "30": 3, "minut": 3, "discuss": 3, "particip": [3, 6, 7], "number": [3, 4, 7], "456": 3, "25": 3, "60": 3, "balanc": 3, "gender": 3, "coverag": 3, "collect": 3, "13": 3, "venu": 3, "categor": 3, "type": 3, "small": 3, "medium": 3, "larg": [3, 4], "room": [3, 4], "size": 3, "rang": 3, "m": 3, "differ": [3, 4, 7], "give": 3, "varieti": 3, "acoust": [3, 4, 7], "properti": 3, "layout": 3, "paramet": [3, 6], "togeth": 3, "wall": 3, "materi": 3, "cover": 3, "cement": 3, "glass": 3, "etc": 3, "other": 3, "furnish": 3, "includ": [3, 4, 6, 7], "sofa": 3, "tv": 3, "blackboard": 3, "fan": 3, "air": 3, "condition": 3, "plant": 3, "record": [3, 7], "sit": 3, "around": 3, "microphon": [3, 4], "arrai": [3, 4], "natur": 3, "convers": 3, "distanc": 3, "nativ": 3, "chines": 3, "speak": [3, 4], "mandarin": [3, 4], "without": 3, "strong": 3, "accent": 3, "variou": [3, 4], "kind": 3, "indoor": 3, "nois": [3, 4, 6], "limit": [3, 4, 6], "click": 3, "keyboard": 3, "door": 3, "close": [3, 4], "bubbl": 3, "made": [3, 4], "both": [3, 7], "remain": [3, 4], "same": [3, 6], "posit": 3, "overlap": [3, 4], "between": [3, 7], "exampl": 3, "fig": 3, "within": [3, 4], "one": [3, 6], "ensur": 3, "ratio": 3, "select": [3, 4, 6, 7], "topic": 3, "medic": 3, "treatment": 3, "educ": 3, "busi": 3, "organ": [3, 4, 6, 7, 8], "manag": 3, "industri": [3, 4], "product": 3, "daili": 3, "routin": 3, "averag": 3, "42": 3, "76": 3, "A": [3, 5], "distribut": 3, "were": 3, "ident": [3, 7], "compris": [3, 4, 8], "therebi": 3, "share": 3, "similar": 3, "configur": 3, "field": [3, 4, 7], "signal": [3, 4], "headset": 3, "": [3, 7], "own": 3, "transcrib": [3, 4, 7], "It": [3, 7], "worth": [3, 7], "note": [3, 7], "far": [3, 4], "synchron": 3, "common": 3, "prepar": 3, "textgrid": 3, "format": 3, "inform": [3, 4], "durat": 3, "id": 3, "timestamp": [3, 7], "mention": 3, "abov": 3, "openslr": 3, "via": 3, "link": 3, "particularli": 3, "conveni": 3, "automat": [4, 8], "recognit": [4, 8], "diariz": 4, "signific": 4, "stride": 4, "recent": 4, "year": 4, "result": [4, 8], "surg": 4, "technologi": 4, "applic": 4, "across": 4, "domain": 4, "present": 4, "uniqu": [4, 7], "complex": [4, 6], "divers": 4, "style": 4, "variabl": 4, "confer": 4, "environment": 4, "reverber": [4, 6], "over": 4, "sever": 4, "been": 4, "advanc": [4, 8], "develop": [4, 7], "rich": 4, "comput": [4, 6], "hear": 4, "multisourc": 4, "environ": 4, "chime": 4, "latest": 4, "iter": 4, "ha": 4, "particular": 4, "focu": 4, "distant": 4, "gener": 4, "topologi": 4, "scenario": 4, "while": 4, "progress": 4, "english": 4, "languag": [4, 6], "barrier": 4, "achiev": 4, "compar": 4, "non": 4, "multimod": 4, "base": 4, "process": [4, 7], "misp": 4, "instrument": 4, "seek": 4, "address": 4, "problem": 4, "visual": 4, "everydai": 4, "home": 4, "focus": 4, "tackl": 4, "issu": 4, "offlin": 4, "icassp2022": 4, "main": 4, "task": [4, 7, 8], "former": 4, "involv": [4, 7], "identifi": 4, "who": 4, "spoke": 4, "when": 4, "latter": 4, "aim": 4, "multipl": [4, 7], "simultan": 4, "pose": [4, 7], "technic": 4, "difficulti": 4, "interfer": 4, "build": [4, 7, 8], "success": [4, 8], "previou": 4, "excit": 4, "propos": [4, 8], "asru": 4, "special": [4, 6, 8], "origin": [4, 6], "metric": [4, 8], "wa": [4, 7], "independ": 4, "meant": 4, "could": 4, "determin": 4, "correspond": [4, 6], "further": 4, "current": [4, 8], "talker": [4, 8], "toward": 4, "practic": 4, "attribut": [4, 8], "what": 4, "facilit": [4, 8], "reproduc": [4, 8], "research": [4, 5, 8], "offer": 4, "comprehens": [4, 8], "overview": [4, 8], "rule": [4, 8], "furthermor": 4, "carefulli": 4, "curat": 4, "approxim": [4, 7], "design": 4, "enabl": 4, "valid": 4, "state": [4, 7, 8], "art": [4, 8], "area": 4, "april": 4, "29": 4, "registr": 4, "mai": 4, "deadlin": 4, "date": 4, "join": 4, "june": 4, "16": 4, "leaderboard": 4, "leaderboar": 4, "juli": 4, "decemb": 4, "12": 4, "workshop": 4, "interest": 4, "whether": 4, "academia": 4, "regist": 4, "complet": 4, "googl": 4, "form": 4, "below": 4, "welcom": 4, "keep": 4, "up": 4, "updat": 4, "work": 4, "dai": 4, "send": 4, "invit": 4, "elig": [4, 6], "qualifi": 4, "adher": [4, 6], "publish": 4, "page": 4, "prior": 4, "submit": 4, "descript": [4, 7], "document": 4, "approach": [4, 6], "method": 4, "top": 4, "asru2023": [4, 8], "proceed": 4, "lei": 5, "xie": 5, "professor": 5, "foundat": 5, "china": 5, "lxie": 5, "nwpu": 5, "edu": 5, "kong": 5, "aik": 5, "lee": 5, "senior": 5, "scientist": 5, "institut": 5, "infocomm": 5, "star": 5, "singapor": 5, "kongaik": 5, "ieee": 5, "org": 5, "zhiji": 5, "yan": 5, "princip": 5, "engin": 5, "alibaba": 5, "yzj": 5, "inc": 5, "shiliang": 5, "zhang": 5, "sly": 5, "zsl": 5, "yanmin": 5, "qian": 5, "shanghai": 5, "jiao": 5, "tong": 5, "univers": 5, "yanminqian": 5, "sjtu": 5, "zhuo": 5, "chen": 5, "appli": 5, "microsoft": 5, "usa": 5, "zhuc": 5, "jian": 5, "wu": 5, "wujian": 5, "hui": 5, "bu": 5, "ceo": 5, "buhui": 5, "aishelldata": 5, "should": 6, "augment": 6, "allow": [6, 7], "ad": 6, "speed": 6, "perturb": 6, "tone": 6, "chang": 6, "permit": 6, "purpos": 6, "instead": [6, 7], "util": [6, 7], "tune": 6, "violat": 6, "strictli": [6, 7], "prohibit": [6, 7], "fine": 6, "cpcer": [6, 7], "lower": 6, "judg": 6, "superior": 6, "forc": 6, "align": 6, "obtain": [6, 7], "frame": 6, "level": 6, "classif": 6, "basi": 6, "shallow": 6, "fusion": 6, "end": 6, "e": [6, 7], "g": 6, "la": 6, "rnnt": 6, "transform": [6, 7], "come": 6, "right": 6, "interpret": 6, "belong": 6, "case": 6, "circumst": 6, "coordin": 6, "assign": 7, "illustr": 7, "aishell4": 7, "constrain": 7, "sourc": 7, "addition": 7, "soon": 7, "simpl": 7, "voic": 7, "activ": 7, "detect": 7, "vad": 7, "concaten": 7, "minimum": 7, "permut": 7, "charact": 7, "error": 7, "rate": 7, "calcul": 7, "step": 7, "firstli": 7, "refer": 7, "hypothesi": 7, "chronolog": 7, "order": 7, "secondli": 7, "repeat": 7, "possibl": 7, "lowest": 7, "tthe": 7, "insert": 7, "Ins": 7, "substitut": 7, "delet": 7, "del": 7, "output": 7, "text": 7, "frac": 7, "mathcal": 7, "n_": 7, "100": 7, "usag": 7, "third": 7, "hug": 7, "face": 7, "list": 7, "clearli": 7, "privat": 7, "simul": 7, "thei": 7, "mandatori": 7, "clear": 7, "scheme": 7, "delight": 8, "introduct": 8, "contact": 8}, "objects": {}, "objtypes": {}, "objnames": {}, "titleterms": {"baselin": 0, "overview": [0, 3], "quick": 0, "start": 0, "result": [0, 1], "challeng": [1, 8], "contact": 2, "dataset": 3, "train": [3, 7], "data": 3, "detail": 3, "alimeet": 3, "corpu": 3, "get": 3, "introduct": 4, "call": 4, "particip": 4, "timelin": 4, "aoe": 4, "time": 4, "guidelin": 4, "organ": 5, "rule": 6, "track": 7, "evalu": 7, "speaker": 7, "attribut": 7, "asr": 7, "metric": 7, "sub": 7, "arrang": 7, "i": 7, "fix": 7, "condit": 7, "ii": 7, "open": 7, "asru": 8, "2023": 8, "multi": 8, "channel": 8, "parti": 8, "meet": 8, "transcript": 8, "2": 8, "0": 8, "m2met2": 8, "content": 8}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 57}, "alltitles": {"Baseline": [[0, "baseline"]], "Overview": [[0, "overview"]], "Quick start": [[0, "quick-start"]], "Baseline results": [[0, "baseline-results"]], "Contact": [[2, "contact"]], "Datasets": [[3, "datasets"]], "Overview of training data": [[3, "overview-of-training-data"]], "Detail of AliMeeting corpus": [[3, "detail-of-alimeeting-corpus"]], "Get the data": [[3, "get-the-data"]], "Introduction": [[4, "introduction"]], "Call for participation": [[4, "call-for-participation"]], "Timeline(AOE Time)": [[4, "timeline-aoe-time"]], "Guidelines": [[4, "guidelines"]], "Organizers": [[5, "organizers"]], "Rules": [[6, "rules"]], "Track & Evaluation": [[7, "track-evaluation"]], "Speaker-Attributed ASR": [[7, "speaker-attributed-asr"]], "Evaluation metric": [[7, "evaluation-metric"]], "Sub-track arrangement": [[7, "sub-track-arrangement"]], "Sub-track I (Fixed Training Condition):": [[7, "sub-track-i-fixed-training-condition"]], "Sub-track II (Open Training Condition):": [[7, "sub-track-ii-open-training-condition"]], "ASRU 2023 MULTI-CHANNEL MULTI-PARTY MEETING TRANSCRIPTION CHALLENGE 2.0 (M2MeT2.0)": [[8, "asru-2023-multi-channel-multi-party-meeting-transcription-challenge-2-0-m2met2-0"]], "Contents:": [[8, null]], "Challenge Result": [[1, "challenge-result"]]}, "indexentries": {}})
\ No newline at end of file
diff --git a/docs/m2met2/index.rst b/docs/m2met2/index.rst
index e0e7562..672cb91 100644
--- a/docs/m2met2/index.rst
+++ b/docs/m2met2/index.rst
@@ -18,5 +18,6 @@
./Track_setting_and_evaluation
./Baseline
./Rules
+ ./Challenge_result
./Organizers
./Contact
diff --git a/docs/m2met2_cn/_build/doctrees/environment.pickle b/docs/m2met2_cn/_build/doctrees/environment.pickle
index a65d613..3e1d79d 100644
--- a/docs/m2met2_cn/_build/doctrees/environment.pickle
+++ b/docs/m2met2_cn/_build/doctrees/environment.pickle
Binary files differ
diff --git a/docs/m2met2_cn/_build/doctrees/index.doctree b/docs/m2met2_cn/_build/doctrees/index.doctree
index 43e1b34..abdde10 100644
--- a/docs/m2met2_cn/_build/doctrees/index.doctree
+++ b/docs/m2met2_cn/_build/doctrees/index.doctree
Binary files differ
diff --git "a/docs/m2met2_cn/_build/doctrees/\346\257\224\350\265\233\347\273\223\346\236\234.doctree" "b/docs/m2met2_cn/_build/doctrees/\346\257\224\350\265\233\347\273\223\346\236\234.doctree"
new file mode 100644
index 0000000..bed03e6
--- /dev/null
+++ "b/docs/m2met2_cn/_build/doctrees/\346\257\224\350\265\233\347\273\223\346\236\234.doctree"
Binary files differ
diff --git a/docs/m2met2_cn/_build/html/_sources/index.rst.txt b/docs/m2met2_cn/_build/html/_sources/index.rst.txt
index 3d9f241..15dee0b 100644
--- a/docs/m2met2_cn/_build/html/_sources/index.rst.txt
+++ b/docs/m2met2_cn/_build/html/_sources/index.rst.txt
@@ -18,5 +18,6 @@
./璧涢亾璁剧疆涓庤瘎浼�
./鍩虹嚎
./瑙勫垯
+ ./姣旇禌缁撴灉
./缁勫浼�
./鑱旂郴鏂瑰紡
diff --git "a/docs/m2met2_cn/_build/html/_sources/\346\257\224\350\265\233\347\273\223\346\236\234.md.txt" "b/docs/m2met2_cn/_build/html/_sources/\346\257\224\350\265\233\347\273\223\346\236\234.md.txt"
new file mode 100644
index 0000000..c577718
--- /dev/null
+++ "b/docs/m2met2_cn/_build/html/_sources/\346\257\224\350\265\233\347\273\223\346\236\234.md.txt"
@@ -0,0 +1,14 @@
+# 姣旇禌缁撴灉
+琛ㄤ腑涓烘湰娆$珵璧涚殑鏈�缁堢粨鏋滐紝鍏朵腑Sub-track1浠h〃闄愬畾鏁版嵁瀛愯禌閬擄紝Sub-track2浠h〃闈為檺瀹氭暟鎹瓙璧涢亾銆傝〃涓暟鎹潎涓篶p-CER锛�%锛夈�傜敱浜庢墍鏈夐槦浼嶇殑鎻愪氦鍧囩鍚堥檺瀹氭暟鎹瓙璧涢亾鐨勮姹傦紝琛ㄤ腑鐨勬帓鍚嶄负涓や釜瀛楄禌閬撳悎骞跺悗鐨勬帓鍚嶃��
+| 鎺掑悕 |闃熶紞鍚嶇О | 瀛愯禌閬撲竴 | 瀛愯禌閬撲簩 | 璁烘枃 |
+|------|----------------------|------------|------------|------------------------|
+| 1 | Ximalaya Speech Team | 11.27 | 11.27 | |
+| 2 | 灏忛┈杈� | 18.64 | 18.64 | |
+| 3 | AIzyzx | 22.83 | 22.83 | |
+| 4 | AsrSpeeder | / | 23.51 | |
+| 5 | zyxlhz | 24.82 | 24.82 | |
+| 6 | CMCAI | 26.11 | / | |
+| 7 | Volcspeech | 34.21 | 34.21 | |
+| 8 | 閴村線鐭ユ潵 | 40.14 | 40.14 | |
+| 9 | baseline | 41.55 | 41.55 | |
+| 10 | DAICT | 41.64 | | |
\ No newline at end of file
diff --git a/docs/m2met2_cn/_build/html/genindex.html b/docs/m2met2_cn/_build/html/genindex.html
index 1eee622..ddb6232 100644
--- a/docs/m2met2_cn/_build/html/genindex.html
+++ b/docs/m2met2_cn/_build/html/genindex.html
@@ -92,6 +92,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git a/docs/m2met2_cn/_build/html/index.html b/docs/m2met2_cn/_build/html/index.html
index b7672cf..5df85ab 100644
--- a/docs/m2met2_cn/_build/html/index.html
+++ b/docs/m2met2_cn/_build/html/index.html
@@ -97,6 +97,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
@@ -132,6 +133,7 @@
<li class="toctree-l1"><a class="reference internal" href="%E8%B5%9B%E9%81%93%E8%AE%BE%E7%BD%AE%E4%B8%8E%E8%AF%84%E4%BC%B0.html">璧涢亾璁剧疆涓庤瘎浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E5%9F%BA%E7%BA%BF.html">鍩虹嚎</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git a/docs/m2met2_cn/_build/html/objects.inv b/docs/m2met2_cn/_build/html/objects.inv
index d846652..8466fe7 100644
--- a/docs/m2met2_cn/_build/html/objects.inv
+++ b/docs/m2met2_cn/_build/html/objects.inv
Binary files differ
diff --git a/docs/m2met2_cn/_build/html/search.html b/docs/m2met2_cn/_build/html/search.html
index ca234a0..1020749 100644
--- a/docs/m2met2_cn/_build/html/search.html
+++ b/docs/m2met2_cn/_build/html/search.html
@@ -85,6 +85,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git a/docs/m2met2_cn/_build/html/searchindex.js b/docs/m2met2_cn/_build/html/searchindex.js
index 0976d1d..65f0147 100644
--- a/docs/m2met2_cn/_build/html/searchindex.js
+++ b/docs/m2met2_cn/_build/html/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["index", "\u57fa\u7ebf", "\u6570\u636e\u96c6", "\u7b80\u4ecb", "\u7ec4\u59d4\u4f1a", "\u8054\u7cfb\u65b9\u5f0f", "\u89c4\u5219", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30"], "filenames": ["index.rst", "\u57fa\u7ebf.md", "\u6570\u636e\u96c6.md", "\u7b80\u4ecb.md", "\u7ec4\u59d4\u4f1a.md", "\u8054\u7cfb\u65b9\u5f0f.md", "\u89c4\u5219.md", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30.md"], "titles": ["ASRU 2023 \u591a\u901a\u9053\u591a\u65b9\u4f1a\u8bae\u8f6c\u5f55\u6311\u6218 2.0", "\u57fa\u7ebf", "\u6570\u636e\u96c6", "\u7b80\u4ecb", "\u7ec4\u59d4\u4f1a", "\u8054\u7cfb\u65b9\u5f0f", "\u7ade\u8d5b\u89c4\u5219", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30"], "terms": {"m2met": [0, 1, 3, 5, 7], "asru2023": [0, 3], "m2met2": [0, 3, 5, 7], "funasr": 1, "sa": 1, "asr": [1, 3, 7], "speakerencod": 1, "modelscop": [1, 7], "instal": 1, "run": 1, "sh": 1, "run_m2met_2023_inf": 1, "alimeet": [1, 3, 5, 7], "dataset": 1, "eval_ali_far": 1, "eval_ali_near": 1, "test_ali_far": 1, "test_ali_near": 1, "train_ali_far": 1, "train_ali_near": 1, "test_2023_ali_far": 1, "16": [1, 3], "wav": 1, "scp": 1, "wav_raw": 1, "segment": 1, "utt2spk": 1, "spk2utt": 1, "data": 1, "aishel": [2, 7], "cn": [2, 4, 7], "celeb": [2, 7], "test": [2, 6, 7], "2023": [2, 3, 6, 7], "118": 2, "75": 2, "104": 2, "train": 2, "eval": [2, 6], "10": [2, 3, 7], "212": 2, "15": 2, "30": 2, "456": 2, "25": 2, "13": 2, "55": 2, "42": 2, "27": 2, "34": 2, "76": 2, "20": [2, 3], "textgrid": 2, "id": 2, "openslr": 2, "baselin": 2, "automat": 3, "speech": 3, "recognit": 3, "speaker": 3, "diariz": 3, "rich": 3, "transcript": 3, "evalu": 3, "chime": 3, "comput": 3, "hear": 3, "in": 3, "multisourc": 3, "environ": 3, "misp": 3, "multimod": 3, "inform": 3, "base": 3, "process": 3, "multi": 3, "channel": 3, "parti": 3, "meet": 3, "iassp2022": 3, "asru": 3, "29": 3, "11": 3, "22": 3, "26": 3, "session": 3, "12": 3, "workshop": 3, "challeng": 3, "gmail": [3, 5], "com": [3, 4, 5], "lxie": 4, "nwpu": 4, "edu": 4, "kong": 4, "aik": 4, "lee": 4, "star": 4, "kongaik": 4, "ieee": 4, "org": 4, "zhiji": 4, "yzj": 4, "alibaba": 4, "inc": 4, "sli": 4, "zsl": 4, "yanminqian": 4, "sjtu": 4, "zhuc": 4, "microsoft": 4, "wujian": 4, "ceo": 4, "buhui": 4, "aishelldata": 4, "cpcer": [6, 7], "las": 6, "rnnt": 6, "transform": 6, "aishell4": 7, "vad": 7, "cer": 7, "ins": 7, "sub": 7, "del": 7, "text": 7, "frac": 7, "mathcal": 7, "n_": 7, "total": 7, "time": 7, "100": 7, "hug": 7, "face": 7}, "objects": {}, "objtypes": {}, "objnames": {}, "titleterms": {"asru": 0, "2023": 0, "alimeet": 2, "aoe": 3}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 57}, "alltitles": {"ASRU 2023 \u591a\u901a\u9053\u591a\u65b9\u4f1a\u8bae\u8f6c\u5f55\u6311\u6218 2.0": [[0, "asru-2023-2-0"]], "\u76ee\u5f55:": [[0, null]], "\u57fa\u7ebf": [[1, "id1"]], "\u57fa\u7ebf\u6982\u8ff0": [[1, "id2"]], "\u5feb\u901f\u5f00\u59cb": [[1, "id3"]], "\u57fa\u7ebf\u7ed3\u679c": [[1, "id4"]], "\u6570\u636e\u96c6": [[2, "id1"]], "\u6570\u636e\u96c6\u6982\u8ff0": [[2, "id2"]], "Alimeeting\u6570\u636e\u96c6\u4ecb\u7ecd": [[2, "alimeeting"]], "\u83b7\u53d6\u6570\u636e": [[2, "id3"]], "\u7b80\u4ecb": [[3, "id1"]], "\u7ade\u8d5b\u4ecb\u7ecd": [[3, "id2"]], "\u65f6\u95f4\u5b89\u6392(AOE\u65f6\u95f4)": [[3, "aoe"]], "\u7ade\u8d5b\u62a5\u540d": [[3, "id3"]], "\u7ec4\u59d4\u4f1a": [[4, "id1"]], "\u8054\u7cfb\u65b9\u5f0f": [[5, "id1"]], "\u7ade\u8d5b\u89c4\u5219": [[6, "id1"]], "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30": [[7, "id1"]], "\u8bf4\u8bdd\u4eba\u76f8\u5173\u7684\u8bed\u97f3\u8bc6\u522b": [[7, "id2"]], "\u8bc4\u4f30\u65b9\u6cd5": [[7, "id3"]], "\u5b50\u8d5b\u9053\u8bbe\u7f6e": [[7, "id4"]], "\u5b50\u8d5b\u9053\u4e00 (\u9650\u5b9a\u8bad\u7ec3\u6570\u636e):": [[7, "id5"]], "\u5b50\u8d5b\u9053\u4e8c (\u5f00\u653e\u8bad\u7ec3\u6570\u636e):": [[7, "id6"]]}, "indexentries": {}})
\ No newline at end of file
+Search.setIndex({"docnames": ["index", "\u57fa\u7ebf", "\u6570\u636e\u96c6", "\u6bd4\u8d5b\u7ed3\u679c", "\u7b80\u4ecb", "\u7ec4\u59d4\u4f1a", "\u8054\u7cfb\u65b9\u5f0f", "\u89c4\u5219", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30"], "filenames": ["index.rst", "\u57fa\u7ebf.md", "\u6570\u636e\u96c6.md", "\u6bd4\u8d5b\u7ed3\u679c.md", "\u7b80\u4ecb.md", "\u7ec4\u59d4\u4f1a.md", "\u8054\u7cfb\u65b9\u5f0f.md", "\u89c4\u5219.md", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30.md"], "titles": ["ASRU 2023 \u591a\u901a\u9053\u591a\u65b9\u4f1a\u8bae\u8f6c\u5f55\u6311\u6218 2.0", "\u57fa\u7ebf", "\u6570\u636e\u96c6", "\u6bd4\u8d5b\u7ed3\u679c", "\u7b80\u4ecb", "\u7ec4\u59d4\u4f1a", "\u8054\u7cfb\u65b9\u5f0f", "\u7ade\u8d5b\u89c4\u5219", "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30"], "terms": {"m2met": [0, 1, 4, 6, 8], "asru2023": [0, 4], "m2met2": [0, 4, 6, 8], "funasr": 1, "sa": 1, "asr": [1, 4, 8], "speakerencod": 1, "modelscop": [1, 8], "instal": 1, "run": 1, "sh": 1, "run_m2met_2023_inf": 1, "alimeet": [1, 4, 6, 8], "dataset": 1, "eval_ali_far": 1, "eval_ali_near": 1, "test_ali_far": 1, "test_ali_near": 1, "train_ali_far": 1, "train_ali_near": 1, "test_2023_ali_far": 1, "16": [1, 4], "wav": 1, "scp": 1, "wav_raw": 1, "segment": 1, "utt2spk": 1, "spk2utt": 1, "data": 1, "aishel": [2, 8], "cn": [2, 5, 8], "celeb": [2, 8], "test": [2, 7, 8], "2023": [2, 4, 7, 8], "118": 2, "75": 2, "104": 2, "train": 2, "eval": [2, 7], "10": [2, 3, 4, 8], "212": 2, "15": 2, "30": 2, "456": 2, "25": 2, "13": 2, "55": [2, 3], "42": 2, "27": [2, 3], "34": [2, 3], "76": 2, "20": [2, 4], "textgrid": 2, "id": 2, "openslr": 2, "baselin": [2, 3], "sub": [3, 8], "track1": 3, "track2": 3, "cp": 3, "cer": [3, 8], "ximalaya": 3, "speech": [3, 4], "team": 3, "11": [3, 4], "18": 3, "64": 3, "aizyzx": 3, "22": [3, 4], "83": 3, "asrspeed": 3, "23": 3, "51": 3, "zyxlhz": 3, "24": 3, "82": 3, "cmcai": 3, "26": [3, 4], "volcspeech": 3, "21": 3, "40": 3, "14": 3, "41": 3, "daict": 3, "automat": 4, "recognit": 4, "speaker": 4, "diariz": 4, "rich": 4, "transcript": 4, "evalu": 4, "chime": 4, "comput": 4, "hear": 4, "in": 4, "multisourc": 4, "environ": 4, "misp": 4, "multimod": 4, "inform": 4, "base": 4, "process": 4, "multi": 4, "channel": 4, "parti": 4, "meet": 4, "iassp2022": 4, "asru": 4, "29": 4, "session": 4, "12": 4, "workshop": 4, "challeng": 4, "gmail": [4, 6], "com": [4, 5, 6], "lxie": 5, "nwpu": 5, "edu": 5, "kong": 5, "aik": 5, "lee": 5, "star": 5, "kongaik": 5, "ieee": 5, "org": 5, "zhiji": 5, "yzj": 5, "alibaba": 5, "inc": 5, "sli": 5, "zsl": 5, "yanminqian": 5, "sjtu": 5, "zhuc": 5, "microsoft": 5, "wujian": 5, "ceo": 5, "buhui": 5, "aishelldata": 5, "cpcer": [7, 8], "las": 7, "rnnt": 7, "transform": 7, "aishell4": 8, "vad": 8, "ins": 8, "del": 8, "text": 8, "frac": 8, "mathcal": 8, "n_": 8, "total": 8, "time": 8, "100": 8, "hug": 8, "face": 8}, "objects": {}, "objtypes": {}, "objnames": {}, "titleterms": {"asru": 0, "2023": 0, "alimeet": 2, "aoe": 4}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 57}, "alltitles": {"ASRU 2023 \u591a\u901a\u9053\u591a\u65b9\u4f1a\u8bae\u8f6c\u5f55\u6311\u6218 2.0": [[0, "asru-2023-2-0"]], "\u76ee\u5f55:": [[0, null]], "\u57fa\u7ebf": [[1, "id1"]], "\u57fa\u7ebf\u6982\u8ff0": [[1, "id2"]], "\u5feb\u901f\u5f00\u59cb": [[1, "id3"]], "\u57fa\u7ebf\u7ed3\u679c": [[1, "id4"]], "\u6570\u636e\u96c6": [[2, "id1"]], "\u6570\u636e\u96c6\u6982\u8ff0": [[2, "id2"]], "Alimeeting\u6570\u636e\u96c6\u4ecb\u7ecd": [[2, "alimeeting"]], "\u83b7\u53d6\u6570\u636e": [[2, "id3"]], "\u6bd4\u8d5b\u7ed3\u679c": [[3, "id1"]], "\u7b80\u4ecb": [[4, "id1"]], "\u7ade\u8d5b\u4ecb\u7ecd": [[4, "id2"]], "\u65f6\u95f4\u5b89\u6392(AOE\u65f6\u95f4)": [[4, "aoe"]], "\u7ade\u8d5b\u62a5\u540d": [[4, "id3"]], "\u7ec4\u59d4\u4f1a": [[5, "id1"]], "\u8054\u7cfb\u65b9\u5f0f": [[6, "id1"]], "\u7ade\u8d5b\u89c4\u5219": [[7, "id1"]], "\u8d5b\u9053\u8bbe\u7f6e\u4e0e\u8bc4\u4f30": [[8, "id1"]], "\u8bf4\u8bdd\u4eba\u76f8\u5173\u7684\u8bed\u97f3\u8bc6\u522b": [[8, "id2"]], "\u8bc4\u4f30\u65b9\u6cd5": [[8, "id3"]], "\u5b50\u8d5b\u9053\u8bbe\u7f6e": [[8, "id4"]], "\u5b50\u8d5b\u9053\u4e00 (\u9650\u5b9a\u8bad\u7ec3\u6570\u636e):": [[8, "id5"]], "\u5b50\u8d5b\u9053\u4e8c (\u5f00\u653e\u8bad\u7ec3\u6570\u636e):": [[8, "id6"]]}, "indexentries": {}})
\ No newline at end of file
diff --git "a/docs/m2met2_cn/_build/html/\345\237\272\347\272\277.html" "b/docs/m2met2_cn/_build/html/\345\237\272\347\272\277.html"
index 9161a64..5cb3797 100644
--- "a/docs/m2met2_cn/_build/html/\345\237\272\347\272\277.html"
+++ "b/docs/m2met2_cn/_build/html/\345\237\272\347\272\277.html"
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git "a/docs/m2met2_cn/_build/html/\346\225\260\346\215\256\351\233\206.html" "b/docs/m2met2_cn/_build/html/\346\225\260\346\215\256\351\233\206.html"
index 016c58f..4198b5a 100644
--- "a/docs/m2met2_cn/_build/html/\346\225\260\346\215\256\351\233\206.html"
+++ "b/docs/m2met2_cn/_build/html/\346\225\260\346\215\256\351\233\206.html"
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git "a/docs/m2met2_cn/_build/html/\346\257\224\350\265\233\347\273\223\346\236\234.html" "b/docs/m2met2_cn/_build/html/\346\257\224\350\265\233\347\273\223\346\236\234.html"
new file mode 100644
index 0000000..d22579c
--- /dev/null
+++ "b/docs/m2met2_cn/_build/html/\346\257\224\350\265\233\347\273\223\346\236\234.html"
@@ -0,0 +1,248 @@
+
+<!DOCTYPE html>
+
+<html lang="zh-CN">
+ <head>
+ <meta charset="utf-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
+
+
+ <!-- Licensed under the Apache 2.0 License -->
+ <link rel="stylesheet" type="text/css" href="_static/fonts/open-sans/stylesheet.css" />
+ <!-- Licensed under the SIL Open Font License -->
+ <link rel="stylesheet" type="text/css" href="_static/fonts/source-serif-pro/source-serif-pro.css" />
+ <link rel="stylesheet" type="text/css" href="_static/css/bootstrap.min.css" />
+ <link rel="stylesheet" type="text/css" href="_static/css/bootstrap-theme.min.css" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+
+ <title>姣旇禌缁撴灉 — 澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</title>
+ <link rel="stylesheet" type="text/css" href="_static/pygments.css" />
+ <link rel="stylesheet" type="text/css" href="_static/guzzle.css" />
+ <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
+ <script src="_static/jquery.js"></script>
+ <script src="_static/underscore.js"></script>
+ <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
+ <script src="_static/doctools.js"></script>
+ <script src="_static/sphinx_highlight.js"></script>
+ <script src="_static/translations.js"></script>
+ <link rel="index" title="绱㈠紩" href="genindex.html" />
+ <link rel="search" title="鎼滅储" href="search.html" />
+ <link rel="next" title="缁勫浼�" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" />
+ <link rel="prev" title="绔炶禌瑙勫垯" href="%E8%A7%84%E5%88%99.html" />
+
+
+
+ </head><body>
+ <div class="related" role="navigation" aria-label="related navigation">
+ <h3>瀵艰埅</h3>
+ <ul>
+ <li class="right" style="margin-right: 10px">
+ <a href="genindex.html" title="鎬荤储寮�"
+ accesskey="I">绱㈠紩</a></li>
+ <li class="right" >
+ <a href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="缁勫浼�"
+ accesskey="N">涓嬩竴椤�</a> |</li>
+ <li class="right" >
+ <a href="%E8%A7%84%E5%88%99.html" title="绔炶禌瑙勫垯"
+ accesskey="P">涓婁竴椤�</a> |</li>
+ <li class="nav-item nav-item-0"><a href="index.html">澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</a> »</li>
+ <li class="nav-item nav-item-this"><a href="">姣旇禌缁撴灉</a></li>
+ </ul>
+ </div>
+ <div class="container-wrapper">
+
+ <div id="mobile-toggle">
+ <a href="#"><span class="glyphicon glyphicon-align-justify" aria-hidden="true"></span></a>
+ </div>
+ <div id="left-column">
+ <div class="sphinxsidebar"><a href="
+ index.html" class="text-logo">澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</a>
+<div class="sidebar-block">
+ <div class="sidebar-wrapper">
+ <div id="main-search">
+ <form class="form-inline" action="search.html" method="GET" role="form">
+ <div class="input-group">
+ <input name="q" type="text" class="form-control" placeholder="Search...">
+ </div>
+ <input type="hidden" name="check_keywords" value="yes" />
+ <input type="hidden" name="area" value="default" />
+ </form>
+ </div>
+ </div>
+</div>
+<div class="sidebar-block">
+ <div class="sidebar-toc">
+
+
+ <p class="caption" role="heading"><span class="caption-text">鐩綍:</span></p>
+<ul class="current">
+<li class="toctree-l1"><a class="reference internal" href="%E7%AE%80%E4%BB%8B.html">绠�浠�</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="%E7%AE%80%E4%BB%8B.html#id2">绔炶禌浠嬬粛</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E7%AE%80%E4%BB%8B.html#aoe">鏃堕棿瀹夋帓(AOE鏃堕棿)</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E7%AE%80%E4%BB%8B.html#id3">绔炶禌鎶ュ悕</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%95%B0%E6%8D%AE%E9%9B%86.html">鏁版嵁闆�</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="%E6%95%B0%E6%8D%AE%E9%9B%86.html#id2">鏁版嵁闆嗘杩�</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E6%95%B0%E6%8D%AE%E9%9B%86.html#alimeeting">Alimeeting鏁版嵁闆嗕粙缁�</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E6%95%B0%E6%8D%AE%E9%9B%86.html#id3">鑾峰彇鏁版嵁</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="%E8%B5%9B%E9%81%93%E8%AE%BE%E7%BD%AE%E4%B8%8E%E8%AF%84%E4%BC%B0.html">璧涢亾璁剧疆涓庤瘎浼�</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="%E8%B5%9B%E9%81%93%E8%AE%BE%E7%BD%AE%E4%B8%8E%E8%AF%84%E4%BC%B0.html#id2">璇磋瘽浜虹浉鍏崇殑璇煶璇嗗埆</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E8%B5%9B%E9%81%93%E8%AE%BE%E7%BD%AE%E4%B8%8E%E8%AF%84%E4%BC%B0.html#id3">璇勪及鏂规硶</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E8%B5%9B%E9%81%93%E8%AE%BE%E7%BD%AE%E4%B8%8E%E8%AF%84%E4%BC%B0.html#id4">瀛愯禌閬撹缃�</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="%E5%9F%BA%E7%BA%BF.html">鍩虹嚎</a><ul>
+<li class="toctree-l2"><a class="reference internal" href="%E5%9F%BA%E7%BA%BF.html#id2">鍩虹嚎姒傝堪</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E5%9F%BA%E7%BA%BF.html#id3">蹇�熷紑濮�</a></li>
+<li class="toctree-l2"><a class="reference internal" href="%E5%9F%BA%E7%BA%BF.html#id4">鍩虹嚎缁撴灉</a></li>
+</ul>
+</li>
+<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1 current"><a class="current reference internal" href="#">姣旇禌缁撴灉</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
+</ul>
+
+
+ </div>
+</div>
+
+ </div>
+ </div>
+ <div id="right-column">
+
+ <div role="navigation" aria-label="breadcrumbs navigation">
+ <ol class="breadcrumb">
+ <li><a href="index.html">Docs</a></li>
+
+ <li>姣旇禌缁撴灉</li>
+ </ol>
+ </div>
+
+ <div class="document clearer body">
+
+ <section id="id1">
+<h1>姣旇禌缁撴灉<a class="headerlink" href="#id1" title="姝ゆ爣棰樼殑姘镐箙閾炬帴">露</a></h1>
+<p>琛ㄤ腑涓烘湰娆$珵璧涚殑鏈�缁堢粨鏋滐紝鍏朵腑Sub-track1浠h〃闄愬畾鏁版嵁瀛愯禌閬擄紝Sub-track2浠h〃闈為檺瀹氭暟鎹瓙璧涢亾銆傝〃涓暟鎹潎涓篶p-CER锛�%锛夈�傜敱浜庢墍鏈夐槦浼嶇殑鎻愪氦鍧囩鍚堥檺瀹氭暟鎹瓙璧涢亾鐨勮姹傦紝琛ㄤ腑鐨勬帓鍚嶄负涓や釜瀛楄禌閬撳悎骞跺悗鐨勬帓鍚嶃��</p>
+<table class="docutils align-default">
+<thead>
+<tr class="row-odd"><th class="head"><p>鎺掑悕 聽 聽</p></th>
+<th class="head"><p>闃熶紞鍚嶇О 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽 聽</p></th>
+<th class="head"><p>瀛愯禌閬撲竴 聽 聽</p></th>
+<th class="head"><p>瀛愯禌閬撲簩 聽 聽</p></th>
+<th class="head"><p>璁烘枃 聽 聽</p></th>
+</tr>
+</thead>
+<tbody>
+<tr class="row-even"><td><p>1</p></td>
+<td><p>Ximalaya Speech Team</p></td>
+<td><p>11.27</p></td>
+<td><p>11.27</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>2</p></td>
+<td><p>灏忛┈杈�</p></td>
+<td><p>18.64</p></td>
+<td><p>18.64</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>3</p></td>
+<td><p>AIzyzx</p></td>
+<td><p>22.83</p></td>
+<td><p>22.83</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>4</p></td>
+<td><p>AsrSpeeder</p></td>
+<td><p>/</p></td>
+<td><p>23.51</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>5</p></td>
+<td><p>zyxlhz</p></td>
+<td><p>24.82</p></td>
+<td><p>24.82</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>6</p></td>
+<td><p>CMCAI</p></td>
+<td><p>26.11</p></td>
+<td><p>/</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>7</p></td>
+<td><p>Volcspeech</p></td>
+<td><p>34.21</p></td>
+<td><p>34.21</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>8</p></td>
+<td><p>閴村線鐭ユ潵</p></td>
+<td><p>40.14</p></td>
+<td><p>40.14</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-even"><td><p>9</p></td>
+<td><p>baseline</p></td>
+<td><p>41.55</p></td>
+<td><p>41.55</p></td>
+<td><p></p></td>
+</tr>
+<tr class="row-odd"><td><p>10</p></td>
+<td><p>DAICT</p></td>
+<td><p>41.64</p></td>
+<td><p></p></td>
+<td><p></p></td>
+</tr>
+</tbody>
+</table>
+</section>
+
+
+ </div>
+
+ <div class="footer-relations">
+
+ <div class="pull-left">
+ <a class="btn btn-default" href="%E8%A7%84%E5%88%99.html" title="涓婁竴绔� (use the left arrow)">绔炶禌瑙勫垯</a>
+ </div>
+
+ <div class="pull-right">
+ <a class="btn btn-default" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="涓嬩竴绔� (use the right arrow)">缁勫浼�</a>
+ </div>
+ </div>
+ <div class="clearer"></div>
+
+ </div>
+ <div class="clearfix"></div>
+ </div>
+ <div class="related" role="navigation" aria-label="related navigation">
+ <h3>瀵艰埅</h3>
+ <ul>
+ <li class="right" style="margin-right: 10px">
+ <a href="genindex.html" title="鎬荤储寮�"
+ >绱㈠紩</a></li>
+ <li class="right" >
+ <a href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="缁勫浼�"
+ >涓嬩竴椤�</a> |</li>
+ <li class="right" >
+ <a href="%E8%A7%84%E5%88%99.html" title="绔炶禌瑙勫垯"
+ >涓婁竴椤�</a> |</li>
+ <li class="nav-item nav-item-0"><a href="index.html">澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</a> »</li>
+ <li class="nav-item nav-item-this"><a href="">姣旇禌缁撴灉</a></li>
+ </ul>
+ </div>
+<script type="text/javascript">
+ $("#mobile-toggle a").click(function () {
+ $("#left-column").toggle();
+ });
+</script>
+<script type="text/javascript" src="_static/js/bootstrap.js"></script>
+ <div class="footer">
+ © Copyright 2023, Speech Lab, Alibaba Group; ASLP Group, Northwestern Polytechnical University. Created using <a href="http://sphinx.pocoo.org/">Sphinx</a>.
+ </div>
+ </body>
+</html>
\ No newline at end of file
diff --git "a/docs/m2met2_cn/_build/html/\347\256\200\344\273\213.html" "b/docs/m2met2_cn/_build/html/\347\256\200\344\273\213.html"
index 05f8847..4628a12 100644
--- "a/docs/m2met2_cn/_build/html/\347\256\200\344\273\213.html"
+++ "b/docs/m2met2_cn/_build/html/\347\256\200\344\273\213.html"
@@ -102,6 +102,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git "a/docs/m2met2_cn/_build/html/\347\273\204\345\247\224\344\274\232.html" "b/docs/m2met2_cn/_build/html/\347\273\204\345\247\224\344\274\232.html"
index e39465f..a280555 100644
--- "a/docs/m2met2_cn/_build/html/\347\273\204\345\247\224\344\274\232.html"
+++ "b/docs/m2met2_cn/_build/html/\347\273\204\345\247\224\344\274\232.html"
@@ -28,7 +28,7 @@
<link rel="index" title="绱㈠紩" href="genindex.html" />
<link rel="search" title="鎼滅储" href="search.html" />
<link rel="next" title="鑱旂郴鏂瑰紡" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html" />
- <link rel="prev" title="绔炶禌瑙勫垯" href="%E8%A7%84%E5%88%99.html" />
+ <link rel="prev" title="姣旇禌缁撴灉" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" />
@@ -43,7 +43,7 @@
<a href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html" title="鑱旂郴鏂瑰紡"
accesskey="N">涓嬩竴椤�</a> |</li>
<li class="right" >
- <a href="%E8%A7%84%E5%88%99.html" title="绔炶禌瑙勫垯"
+ <a href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="姣旇禌缁撴灉"
accesskey="P">涓婁竴椤�</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</a> »</li>
<li class="nav-item nav-item-this"><a href="">缁勫浼�</a></li>
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
@@ -164,7 +165,7 @@
<div class="footer-relations">
<div class="pull-left">
- <a class="btn btn-default" href="%E8%A7%84%E5%88%99.html" title="涓婁竴绔� (use the left arrow)">绔炶禌瑙勫垯</a>
+ <a class="btn btn-default" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="涓婁竴绔� (use the left arrow)">姣旇禌缁撴灉</a>
</div>
<div class="pull-right">
@@ -186,7 +187,7 @@
<a href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html" title="鑱旂郴鏂瑰紡"
>涓嬩竴椤�</a> |</li>
<li class="right" >
- <a href="%E8%A7%84%E5%88%99.html" title="绔炶禌瑙勫垯"
+ <a href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="姣旇禌缁撴灉"
>涓婁竴椤�</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">澶氶�氶亾澶氭柟浼氳杞綍鎸戞垬2.0</a> »</li>
<li class="nav-item nav-item-this"><a href="">缁勫浼�</a></li>
diff --git "a/docs/m2met2_cn/_build/html/\350\201\224\347\263\273\346\226\271\345\274\217.html" "b/docs/m2met2_cn/_build/html/\350\201\224\347\263\273\346\226\271\345\274\217.html"
index fc060e8..095df78 100644
--- "a/docs/m2met2_cn/_build/html/\350\201\224\347\263\273\346\226\271\345\274\217.html"
+++ "b/docs/m2met2_cn/_build/html/\350\201\224\347\263\273\346\226\271\345\274\217.html"
@@ -97,6 +97,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git "a/docs/m2met2_cn/_build/html/\350\247\204\345\210\231.html" "b/docs/m2met2_cn/_build/html/\350\247\204\345\210\231.html"
index 7d54533..281cca7 100644
--- "a/docs/m2met2_cn/_build/html/\350\247\204\345\210\231.html"
+++ "b/docs/m2met2_cn/_build/html/\350\247\204\345\210\231.html"
@@ -27,7 +27,7 @@
<script src="_static/translations.js"></script>
<link rel="index" title="绱㈠紩" href="genindex.html" />
<link rel="search" title="鎼滅储" href="search.html" />
- <link rel="next" title="缁勫浼�" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" />
+ <link rel="next" title="姣旇禌缁撴灉" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" />
<link rel="prev" title="鍩虹嚎" href="%E5%9F%BA%E7%BA%BF.html" />
@@ -40,7 +40,7 @@
<a href="genindex.html" title="鎬荤储寮�"
accesskey="I">绱㈠紩</a></li>
<li class="right" >
- <a href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="缁勫浼�"
+ <a href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="姣旇禌缁撴灉"
accesskey="N">涓嬩竴椤�</a> |</li>
<li class="right" >
<a href="%E5%9F%BA%E7%BA%BF.html" title="鍩虹嚎"
@@ -101,6 +101,7 @@
</ul>
</li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
@@ -146,7 +147,7 @@
</div>
<div class="pull-right">
- <a class="btn btn-default" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="涓嬩竴绔� (use the right arrow)">缁勫浼�</a>
+ <a class="btn btn-default" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="涓嬩竴绔� (use the right arrow)">姣旇禌缁撴灉</a>
</div>
</div>
<div class="clearer"></div>
@@ -161,7 +162,7 @@
<a href="genindex.html" title="鎬荤储寮�"
>绱㈠紩</a></li>
<li class="right" >
- <a href="%E7%BB%84%E5%A7%94%E4%BC%9A.html" title="缁勫浼�"
+ <a href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html" title="姣旇禌缁撴灉"
>涓嬩竴椤�</a> |</li>
<li class="right" >
<a href="%E5%9F%BA%E7%BA%BF.html" title="鍩虹嚎"
diff --git "a/docs/m2met2_cn/_build/html/\350\265\233\351\201\223\350\256\276\347\275\256\344\270\216\350\257\204\344\274\260.html" "b/docs/m2met2_cn/_build/html/\350\265\233\351\201\223\350\256\276\347\275\256\344\270\216\350\257\204\344\274\260.html"
index c9a15f9..ddc419f 100644
--- "a/docs/m2met2_cn/_build/html/\350\265\233\351\201\223\350\256\276\347\275\256\344\270\216\350\257\204\344\274\260.html"
+++ "b/docs/m2met2_cn/_build/html/\350\265\233\351\201\223\350\256\276\347\275\256\344\270\216\350\257\204\344\274\260.html"
@@ -102,6 +102,7 @@
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="%E8%A7%84%E5%88%99.html">绔炶禌瑙勫垯</a></li>
+<li class="toctree-l1"><a class="reference internal" href="%E6%AF%94%E8%B5%9B%E7%BB%93%E6%9E%9C.html">姣旇禌缁撴灉</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E7%BB%84%E5%A7%94%E4%BC%9A.html">缁勫浼�</a></li>
<li class="toctree-l1"><a class="reference internal" href="%E8%81%94%E7%B3%BB%E6%96%B9%E5%BC%8F.html">鑱旂郴鏂瑰紡</a></li>
</ul>
diff --git a/docs/m2met2_cn/index.rst b/docs/m2met2_cn/index.rst
index 3d9f241..15dee0b 100644
--- a/docs/m2met2_cn/index.rst
+++ b/docs/m2met2_cn/index.rst
@@ -18,5 +18,6 @@
./璧涢亾璁剧疆涓庤瘎浼�
./鍩虹嚎
./瑙勫垯
+ ./姣旇禌缁撴灉
./缁勫浼�
./鑱旂郴鏂瑰紡
diff --git "a/docs/m2met2_cn/\346\257\224\350\265\233\347\273\223\346\236\234.md" "b/docs/m2met2_cn/\346\257\224\350\265\233\347\273\223\346\236\234.md"
new file mode 100644
index 0000000..c577718
--- /dev/null
+++ "b/docs/m2met2_cn/\346\257\224\350\265\233\347\273\223\346\236\234.md"
@@ -0,0 +1,14 @@
+# 姣旇禌缁撴灉
+琛ㄤ腑涓烘湰娆$珵璧涚殑鏈�缁堢粨鏋滐紝鍏朵腑Sub-track1浠h〃闄愬畾鏁版嵁瀛愯禌閬擄紝Sub-track2浠h〃闈為檺瀹氭暟鎹瓙璧涢亾銆傝〃涓暟鎹潎涓篶p-CER锛�%锛夈�傜敱浜庢墍鏈夐槦浼嶇殑鎻愪氦鍧囩鍚堥檺瀹氭暟鎹瓙璧涢亾鐨勮姹傦紝琛ㄤ腑鐨勬帓鍚嶄负涓や釜瀛楄禌閬撳悎骞跺悗鐨勬帓鍚嶃��
+| 鎺掑悕 |闃熶紞鍚嶇О | 瀛愯禌閬撲竴 | 瀛愯禌閬撲簩 | 璁烘枃 |
+|------|----------------------|------------|------------|------------------------|
+| 1 | Ximalaya Speech Team | 11.27 | 11.27 | |
+| 2 | 灏忛┈杈� | 18.64 | 18.64 | |
+| 3 | AIzyzx | 22.83 | 22.83 | |
+| 4 | AsrSpeeder | / | 23.51 | |
+| 5 | zyxlhz | 24.82 | 24.82 | |
+| 6 | CMCAI | 26.11 | / | |
+| 7 | Volcspeech | 34.21 | 34.21 | |
+| 8 | 閴村線鐭ユ潵 | 40.14 | 40.14 | |
+| 9 | baseline | 41.55 | 41.55 | |
+| 10 | DAICT | 41.64 | | |
\ No newline at end of file
diff --git a/docs/runtime/img.png b/docs/runtime/img.png
new file mode 100644
index 0000000..84e2efe
--- /dev/null
+++ b/docs/runtime/img.png
Binary files differ
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 59f9936..a1f27a3 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -6,7 +6,7 @@
unified_model_training: true
default_chunk_size: 16
jitter_range: 4
- left_chunk_size: 0
+ left_chunk_size: 1
embed_vgg_like: false
subsampling_factor: 4
linear_units: 2048
@@ -51,7 +51,7 @@
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 200
+max_epoch: 120
val_scheduler_criterion:
- valid
- loss
diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa-asr/README.md
deleted file mode 100644
index 2ef6bbe..0000000
--- a/egs/alimeeting/sa-asr/README.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Get Started
-Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
-To run this receipe, first you need to install FunASR and ModelScope. ([installation](https://github.com/alibaba-damo-academy/FunASR#installation))
-There are two startup scripts, `run.sh` for training and evaluating on the old eval and test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.
-Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory:
-```shell
-dataset
-|鈥斺�� Eval_Ali_far
-|鈥斺�� Eval_Ali_near
-|鈥斺�� Test_Ali_far
-|鈥斺�� Test_Ali_near
-|鈥斺�� Train_Ali_far
-|鈥斺�� Train_Ali_near
-```
-There are 16 stages in `run.sh`:
-```shell
-stage 1 - 5: Data preparation and processing.
-stage 6: Generate speaker profiles (Stage 6 takes a lot of time).
-stage 7 - 9: Language model training (Optional).
-stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
-stage 12: SA-ASR training.
-stage 13 - 16: Inference and evaluation.
-```
-Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts) in the `./dataset` directory, which contains only raw audios. Then put the given `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` in the `./data/Test_2023_Ali_far` directory.
-```shell
-data/Test_2023_Ali_far
-|鈥斺�� wav.scp
-|鈥斺�� wav_raw.scp
-|鈥斺�� segments
-|鈥斺�� utt2spk
-|鈥斺�� spk2utt
-```
-There are 4 stages in `run_m2met_2023_infer.sh`:
-```shell
-stage 1: Data preparation and processing.
-stage 2: Generate speaker profiles for inference.
-stage 3: Inference.
-stage 4: Generation of SA-ASR results required for final submission.
-```
-
-The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary).
-After generate stats of AliMeeting corpus(stage 10 in `run.sh`), you can set the `infer_with_pretrained_model=true` in `run.sh` to infer with our official baseline model released on ModelScope without training.
-
-# Format of Final Submission
-Finally, you need to submit a file called `text_spk_merge` with the following format:
-```shell
-Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
-Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
-...
-```
-Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code.
-# Baseline Results
-The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Eval and Test Set are also provided to show the impact of speaker profile accuracy.
-<table>
- <tr >
- <td rowspan="2"></td>
- <td colspan="2">SI-CER(%)</td>
- <td colspan="2">cpCER(%)</td>
- </tr>
- <tr>
- <td>Eval</td>
- <td>Test</td>
- <td>Eval</td>
- <td>Test</td>
- </tr>
- <tr>
- <td>oracle profile</td>
- <td>32.05</td>
- <td>32.70</td>
- <td>47.40</td>
- <td>52.57</td>
- </tr>
- <tr>
- <td>cluster profile</td>
- <td>32.05</td>
- <td>32.70</td>
- <td>53.76</td>
- <td>55.95</td>
- </tr>
-</table>
-
-# Reference
-N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413鈥�4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
deleted file mode 100755
index 30401b9..0000000
--- a/egs/alimeeting/sa-asr/asr_local.sh
+++ /dev/null
@@ -1,1483 +0,0 @@
-#!/usr/bin/env bash
-
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-log() {
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-min() {
- local a b
- a=$1
- for b in "$@"; do
- if [ "${b}" -le "${a}" ]; then
- a="${b}"
- fi
- done
- echo "${a}"
-}
-SECONDS=0
-
-# General configuration
-stage=1 # Processes starts from the specified stage.
-stop_stage=10000 # Processes is stopped at the specified stage.
-skip_data_prep=false # Skip data preparation stages.
-skip_train=false # Skip training stages.
-skip_eval=false # Skip decoding and evaluation stages.
-skip_upload=true # Skip packing and uploading stages.
-ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
-num_nodes=1 # The number of nodes.
-nj=16 # The number of parallel jobs.
-inference_nj=16 # The number of parallel jobs in decoding.
-gpu_inference=false # Whether to perform gpu decoding.
-njob_infer=4
-dumpdir=dump2 # Directory to dump features.
-expdir=exp # Directory to save experiments.
-python=python3 # Specify python to execute espnet commands.
-device=0
-
-# Data preparation related
-local_data_opts= # The options given to local/data.sh.
-
-# Speed perturbation related
-speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
-
-# Feature extraction related
-feats_type=raw # Feature type (raw or fbank_pitch).
-audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
-fs=16000 # Sampling rate.
-min_wav_duration=0.1 # Minimum duration in second.
-max_wav_duration=20 # Maximum duration in second.
-
-# Tokenization related
-token_type=bpe # Tokenization type (char or bpe).
-nbpe=30 # The number of BPE vocabulary.
-bpemode=unigram # Mode of BPE (unigram or bpe).
-oov="<unk>" # Out of vocabulary symbol.
-blank="<blank>" # CTC blank symbol
-sos_eos="<sos/eos>" # sos and eos symbole
-bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
-bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
-bpe_char_cover=1.0 # character coverage when modeling BPE
-
-# Language model related
-use_lm=true # Use language model for ASR decoding.
-lm_tag= # Suffix to the result dir for language model training.
-lm_exp= # Specify the direcotry path for LM experiment.
- # If this option is specified, lm_tag is ignored.
-lm_stats_dir= # Specify the direcotry path for LM statistics.
-lm_config= # Config for language model training.
-lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
- # Note that it will overwrite args in lm config.
-use_word_lm=false # Whether to use word language model.
-num_splits_lm=1 # Number of splitting for lm corpus.
-# shellcheck disable=SC2034
-word_vocab_size=10000 # Size of word vocabulary.
-
-# ASR model related
-asr_tag= # Suffix to the result dir for asr model training.
-asr_exp= # Specify the direcotry path for ASR experiment.
- # If this option is specified, asr_tag is ignored.
-sa_asr_exp=
-asr_stats_dir= # Specify the direcotry path for ASR statistics.
-asr_config= # Config for asr model training.
-sa_asr_config=
-asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
- # Note that it will overwrite args in asr config.
-feats_normalize=global_mvn # Normalizaton layer type.
-num_splits_asr=1 # Number of splitting for lm corpus.
-
-# Decoding related
-inference_tag= # Suffix to the result dir for decoding.
-inference_config= # Config for decoding.
-inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
- # Note that it will overwrite args in inference config.
-sa_asr_inference_tag=
-sa_asr_inference_args=
-
-inference_lm=valid.loss.ave.pb # Language modle path for decoding.
-inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
- # e.g.
- # inference_asr_model=train.loss.best.pth
- # inference_asr_model=3epoch.pth
- # inference_asr_model=valid.acc.best.pth
- # inference_asr_model=valid.loss.ave.pth
-inference_sa_asr_model=valid.acc_spk.ave.pb
-infer_with_pretrained_model=false # Use pretrained model for decoding
-download_sa_asr_model= # Download the SA-ASR model from ModelScope and use it for decoding.
-# [Task dependent] Set the datadir name created by local/data.sh
-train_set= # Name of training set.
-valid_set= # Name of validation set used for monitoring/tuning network training.
-test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
-bpe_train_text= # Text file path of bpe training set.
-lm_train_text= # Text file path of language model training set.
-lm_dev_text= # Text file path of language model development set.
-lm_test_text= # Text file path of language model evaluation set.
-nlsyms_txt=none # Non-linguistic symbol list if existing.
-cleaner=none # Text cleaner.
-g2p=none # g2p method (needed if token_type=phn).
-lang=zh # The language type of corpus.
-score_opts= # The options given to sclite scoring
-local_score_opts= # The options given to local/score.sh.
-
-
-help_message=$(cat << EOF
-Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
-
-Options:
- # General configuration
- --stage # Processes starts from the specified stage (default="${stage}").
- --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
- --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
- --skip_train # Skip training stages (default="${skip_train}").
- --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
- --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
- --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
- --num_nodes # The number of nodes (default="${num_nodes}").
- --nj # The number of parallel jobs (default="${nj}").
- --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
- --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
- --dumpdir # Directory to dump features (default="${dumpdir}").
- --expdir # Directory to save experiments (default="${expdir}").
- --python # Specify python to execute espnet commands (default="${python}").
- --device # Which GPUs are use for local training (defalut="${device}").
-
- # Data preparation related
- --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
-
- # Speed perturbation related
- --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
-
- # Feature extraction related
- --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
- --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
- --fs # Sampling rate (default="${fs}").
- --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
- --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
-
- # Tokenization related
- --token_type # Tokenization type (char or bpe, default="${token_type}").
- --nbpe # The number of BPE vocabulary (default="${nbpe}").
- --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
- --oov # Out of vocabulary symbol (default="${oov}").
- --blank # CTC blank symbol (default="${blank}").
- --sos_eos # sos and eos symbole (default="${sos_eos}").
- --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
- --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
- --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
-
- # Language model related
- --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
- --lm_exp # Specify the direcotry path for LM experiment.
- # If this option is specified, lm_tag is ignored (default="${lm_exp}").
- --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
- --lm_config # Config for language model training (default="${lm_config}").
- --lm_args # Arguments for language model training (default="${lm_args}").
- # e.g., --lm_args "--max_epoch 10"
- # Note that it will overwrite args in lm config.
- --use_word_lm # Whether to use word language model (default="${use_word_lm}").
- --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
- --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
-
- # ASR model related
- --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
- --asr_exp # Specify the direcotry path for ASR experiment.
- # If this option is specified, asr_tag is ignored (default="${asr_exp}").
- --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
- --asr_config # Config for asr model training (default="${asr_config}").
- --asr_args # Arguments for asr model training (default="${asr_args}").
- # e.g., --asr_args "--max_epoch 10"
- # Note that it will overwrite args in asr config.
- --feats_normalize # Normalizaton layer type (default="${feats_normalize}").
- --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
-
- # Decoding related
- --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
- --inference_config # Config for decoding (default="${inference_config}").
- --inference_args # Arguments for decoding (default="${inference_args}").
- # e.g., --inference_args "--lm_weight 0.1"
- # Note that it will overwrite args in inference config.
- --inference_lm # Language modle path for decoding (default="${inference_lm}").
- --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
- --infer_with_pretrained_model # Use pretrained model for decoding (default="${infer_with_pretrained_model}").
- --download_sa_asr_model= # Download the SA-ASR model from ModelScope and use it for decoding(default="${download_sa_asr_model}").
-
- # [Task dependent] Set the datadir name created by local/data.sh
- --train_set # Name of training set (required).
- --valid_set # Name of validation set used for monitoring/tuning network training (required).
- --test_sets # Names of test sets.
- # Multiple items (e.g., both dev and eval sets) can be specified (required).
- --bpe_train_text # Text file path of bpe training set.
- --lm_train_text # Text file path of language model training set.
- --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
- --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
- --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
- --cleaner # Text cleaner (default="${cleaner}").
- --g2p # g2p method (default="${g2p}").
- --lang # The language type of corpus (default=${lang}).
- --score_opts # The options given to sclite scoring (default="{score_opts}").
- --local_score_opts # The options given to local/score.sh (default="{local_score_opts}").
-EOF
-)
-
-log "$0 $*"
-# Save command line args for logging (they will be lost after utils/parse_options.sh)
-run_args=$(python -m funasr.utils.cli_utils $0 "$@")
-. utils/parse_options.sh
-
-if [ $# -ne 0 ]; then
- log "${help_message}"
- log "Error: No positional arguments are required."
- exit 2
-fi
-
-. ./path.sh
-
-
-# Check required arguments
-[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
-[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
-[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
-
-# Check feature type
-if [ "${feats_type}" = raw ]; then
- data_feats=${dumpdir}/raw
-elif [ "${feats_type}" = fbank_pitch ]; then
- data_feats=${dumpdir}/fbank_pitch
-elif [ "${feats_type}" = fbank ]; then
- data_feats=${dumpdir}/fbank
-elif [ "${feats_type}" == extracted ]; then
- data_feats=${dumpdir}/extracted
-else
- log "${help_message}"
- log "Error: not supported: --feats_type ${feats_type}"
- exit 2
-fi
-
-# Use the same text as ASR for bpe training if not specified.
-[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
-# Use the same text as ASR for lm training if not specified.
-[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
-# Use the same text as ASR for lm training if not specified.
-[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
-# Use the text of the 1st evaldir if lm_test is not specified
-[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
-
-# Check tokenization type
-if [ "${lang}" != noinfo ]; then
- token_listdir=data/${lang}_token_list
-else
- token_listdir=data/token_list
-fi
-bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
-bpeprefix="${bpedir}"/bpe
-bpemodel="${bpeprefix}".model
-bpetoken_list="${bpedir}"/tokens.txt
-chartoken_list="${token_listdir}"/char/tokens.txt
-# NOTE: keep for future development.
-# shellcheck disable=SC2034
-wordtoken_list="${token_listdir}"/word/tokens.txt
-
-if [ "${token_type}" = bpe ]; then
- token_list="${bpetoken_list}"
-elif [ "${token_type}" = char ]; then
- token_list="${chartoken_list}"
- bpemodel=none
-elif [ "${token_type}" = word ]; then
- token_list="${wordtoken_list}"
- bpemodel=none
-else
- log "Error: not supported --token_type '${token_type}'"
- exit 2
-fi
-if ${use_word_lm}; then
- log "Error: Word LM is not supported yet"
- exit 2
-
- lm_token_list="${wordtoken_list}"
- lm_token_type=word
-else
- lm_token_list="${token_list}"
- lm_token_type="${token_type}"
-fi
-
-if ${infer_with_pretrained_model}; then
- skip_train=true
-fi
-
-# Set tag for naming of model directory
-if [ -z "${asr_tag}" ]; then
- if [ -n "${asr_config}" ]; then
- asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
- else
- asr_tag="train_${feats_type}"
- fi
- if [ "${lang}" != noinfo ]; then
- asr_tag+="_${lang}_${token_type}"
- else
- asr_tag+="_${token_type}"
- fi
- if [ "${token_type}" = bpe ]; then
- asr_tag+="${nbpe}"
- fi
- # Add overwritten arg's info
- if [ -n "${asr_args}" ]; then
- asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
- fi
- if [ -n "${speed_perturb_factors}" ]; then
- asr_tag+="_sp"
- fi
-fi
-if [ -z "${lm_tag}" ]; then
- if [ -n "${lm_config}" ]; then
- lm_tag="$(basename "${lm_config}" .yaml)"
- else
- lm_tag="train"
- fi
- if [ "${lang}" != noinfo ]; then
- lm_tag+="_${lang}_${lm_token_type}"
- else
- lm_tag+="_${lm_token_type}"
- fi
- if [ "${lm_token_type}" = bpe ]; then
- lm_tag+="${nbpe}"
- fi
- # Add overwritten arg's info
- if [ -n "${lm_args}" ]; then
- lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
- fi
-fi
-
-# The directory used for collect-stats mode
-if [ -z "${asr_stats_dir}" ]; then
- if [ "${lang}" != noinfo ]; then
- asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
- else
- asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
- fi
- if [ "${token_type}" = bpe ]; then
- asr_stats_dir+="${nbpe}"
- fi
- if [ -n "${speed_perturb_factors}" ]; then
- asr_stats_dir+="_sp"
- fi
-fi
-if [ -z "${lm_stats_dir}" ]; then
- if [ "${lang}" != noinfo ]; then
- lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
- else
- lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
- fi
- if [ "${lm_token_type}" = bpe ]; then
- lm_stats_dir+="${nbpe}"
- fi
-fi
-# The directory used for training commands
-if [ -z "${asr_exp}" ]; then
- asr_exp="${expdir}/asr_${asr_tag}"
-fi
-if [ -z "${lm_exp}" ]; then
- lm_exp="${expdir}/lm_${lm_tag}"
-fi
-
-
-if [ -z "${inference_tag}" ]; then
- if [ -n "${inference_config}" ]; then
- inference_tag="$(basename "${inference_config}" .yaml)"
- else
- inference_tag=inference
- fi
- # Add overwritten arg's info
- if [ -n "${inference_args}" ]; then
- inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
- fi
- if "${use_lm}"; then
- inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
- fi
- inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
-fi
-
-if [ -z "${sa_asr_inference_tag}" ]; then
- if [ -n "${inference_config}" ]; then
- sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
- else
- sa_asr_inference_tag=sa_asr_inference
- fi
- # Add overwritten arg's info
- if [ -n "${sa_asr_inference_args}" ]; then
- sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
- fi
- if "${use_lm}"; then
- sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
- fi
- sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
-fi
-
-train_cmd="run.pl"
-cuda_cmd="run.pl"
-decode_cmd="run.pl"
-
-# ========================== Main stages start from here. ==========================
-
-if ! "${skip_data_prep}"; then
-
- if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc."
-
- ./local/alimeeting_data_prep.sh --tgt Test
- ./local/alimeeting_data_prep.sh --tgt Eval
- ./local/alimeeting_data_prep.sh --tgt Train
- fi
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- if [ -n "${speed_perturb_factors}" ]; then
- log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
- for factor in ${speed_perturb_factors}; do
- if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
- local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
- _dirs+="data/${train_set}_sp${factor} "
- else
- # If speed factor is 1, same as the original
- _dirs+="data/${train_set} "
- fi
- done
- local/combine_data.sh "data/${train_set}_sp" ${_dirs}
- else
- log "Skip stage 2: Speed perturbation"
- fi
- fi
-
- if [ -n "${speed_perturb_factors}" ]; then
- train_set="${train_set}_sp"
- fi
-
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- if [ "${feats_type}" = raw ]; then
- log "Stage 3: Format wav.scp: data/ -> ${data_feats}"
-
- # ====== Recreating "wav.scp" ======
- # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
- # shouldn't be used in training process.
- # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
- # and it can also change the audio-format and sampling rate.
- # If nothing is need, then format_wav_scp.sh does nothing:
- # i.e. the input file format and rate is same as the output.
-
- for dset in "${train_set}" "${valid_set}" "${test_sets}" ; do
- if [ "${dset}" = "${train_set}" ] || [ "${dset}" = "${valid_set}" ]; then
- _suf="/org"
- else
- if [ "${dset}" = "${test_sets}" ] && [ "${test_sets}" = "Test_Ali_far" ]; then
- _suf="/org"
- else
- _suf=""
- fi
- fi
- local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
-
- if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
- cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
- fi
-
- rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
- _opts=
- if [ -e data/"${dset}"/segments ]; then
- # "segments" is used for splitting wav files which are written in "wav".scp
- # into utterances. The file format of segments:
- # <segment_id> <record_id> <start_time> <end_time>
- # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
- # Where the time is written in seconds.
- _opts+="--segments data/${dset}/segments "
- fi
- # shellcheck disable=SC2086
- local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
- --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
- "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
-
- echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
- done
-
- else
- log "Error: not supported: --feats_type ${feats_type}"
- exit 2
- fi
- fi
-
-
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- log "Stage 4: Remove long/short data: ${data_feats}/org -> ${data_feats}"
-
- # NOTE(kamo): Not applying to test_sets to keep original data
- if [ "${test_sets}" = "Test_Ali_far" ]; then
- rm_dset="${train_set} ${valid_set} ${test_sets}"
- else
- rm_dset="${train_set} ${valid_set}"
- fi
-
- for dset in $rm_dset; do
-
- # Copy data dir
- local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
- cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
-
- # Remove short utterances
- _feats_type="$(<${data_feats}/${dset}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
- _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
- _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))")
-
- # utt2num_samples is created by format_wav_scp.sh
- <"${data_feats}/org/${dset}/utt2num_samples" \
- awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
- '{ if ($2 > min_length && $2 < max_length ) print $0; }' \
- >"${data_feats}/${dset}/utt2num_samples"
- <"${data_feats}/org/${dset}/wav.scp" \
- utils/filter_scp.pl "${data_feats}/${dset}/utt2num_samples" \
- >"${data_feats}/${dset}/wav.scp"
- else
- # Get frame shift in ms from conf/fbank.conf
- _frame_shift=
- if [ -f conf/fbank.conf ] && [ "$(<conf/fbank.conf grep -c frame-shift)" -gt 0 ]; then
- # Assume using conf/fbank.conf for feature extraction
- _frame_shift="$(<conf/fbank.conf grep frame-shift | sed -e 's/[-a-z =]*\([0-9]*\)/\1/g')"
- fi
- if [ -z "${_frame_shift}" ]; then
- # If not existing, use the default number in Kaldi (=10ms).
- # If you are using different number, you have to change the following value manually.
- _frame_shift=10
- fi
-
- _min_length=$(python3 -c "print(int(${min_wav_duration} / ${_frame_shift} * 1000))")
- _max_length=$(python3 -c "print(int(${max_wav_duration} / ${_frame_shift} * 1000))")
-
- cp "${data_feats}/org/${dset}/feats_dim" "${data_feats}/${dset}/feats_dim"
- <"${data_feats}/org/${dset}/feats_shape" awk -F, ' { print $1 } ' \
- | awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
- '{ if ($2 > min_length && $2 < max_length) print $0; }' \
- >"${data_feats}/${dset}/feats_shape"
- <"${data_feats}/org/${dset}/feats.scp" \
- utils/filter_scp.pl "${data_feats}/${dset}/feats_shape" \
- >"${data_feats}/${dset}/feats.scp"
- fi
-
- # Remove empty text
- <"${data_feats}/org/${dset}/text" \
- awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
-
- # fix_data_dir.sh leaves only utts which exist in all files
- local/fix_data_dir.sh "${data_feats}/${dset}"
-
- # generate uttid
- cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
-
- if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
- # filter utt2spk_all_fifo
- python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
- fi
- done
-
- # shellcheck disable=SC2002
- cat ${lm_train_text} | awk ' { if( NF != 1 ) print $0; } ' > "${data_feats}/lm_train.txt"
- fi
-
-
- if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
- log "Stage 5: Dictionary Preparation"
- mkdir -p data/${lang}_token_list/char/
-
- echo "make a dictionary"
- echo "<blank>" > ${token_list}
- echo "<s>" >> ${token_list}
- echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
- | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
- echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- fi
-
- if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
- log "Stage 6: Generate speaker settings"
- mkdir -p "profile_log"
- for dset in "${train_set}" "${valid_set}" "${test_sets}"; do
- # generate text_id spk2id
- python local/process_sot_fifo_textchar2spk.py --path ${data_feats}/${dset}
- log "Successfully generate ${data_feats}/${dset}/text_id ${data_feats}/${dset}/spk2id"
- # generate text_id_train for sot
- python local/process_text_id.py ${data_feats}/${dset}
- log "Successfully generate ${data_feats}/${dset}/text_id_train"
- # generate oracle_embedding from single-speaker audio segment
- log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
- python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
- log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
- # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
- if [ "${dset}" = "${train_set}" ]; then
- python local/gen_oracle_profile_padding.py ${data_feats}/${dset}
- log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_padding.scp)"
- else
- python local/gen_oracle_profile_nopadding.py ${data_feats}/${dset}
- log "Successfully generate oracle profile for ${dset} (${data_feats}/${dset}/oracle_profile_nopadding.scp)"
- fi
- # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
- if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
- log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
- python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
- log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
- fi
-
- done
- fi
-
-else
- log "Skip the stages for data preparation"
-fi
-
-
-# ========================== Data preparation is done here. ==========================
-
-
-if ! "${skip_train}"; then
- if "${use_lm}"; then
- if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
- log "Stage 7: LM collect stats: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
-
- _opts=
- if [ -n "${lm_config}" ]; then
- # To generate the config file: e.g.
- # % python3 -m espnet2.bin.lm_train --print_config --optim adam
- _opts+="--config ${lm_config} "
- fi
-
- # 1. Split the key file
- _logdir="${lm_stats_dir}/logdir"
- mkdir -p "${_logdir}"
- # Get the minimum number among ${nj} and the number lines of input files
- _nj=$(min "${nj}" "$(<${data_feats}/lm_train.txt wc -l)" "$(<${lm_dev_text} wc -l)")
-
- key_file="${data_feats}/lm_train.txt"
- split_scps=""
- for n in $(seq ${_nj}); do
- split_scps+=" ${_logdir}/train.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- key_file="${lm_dev_text}"
- split_scps=""
- for n in $(seq ${_nj}); do
- split_scps+=" ${_logdir}/dev.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Generate run.sh
- log "Generate '${lm_stats_dir}/run.sh'. You can resume the process from stage 6 using this script"
- mkdir -p "${lm_stats_dir}"; echo "${run_args} --stage 6 \"\$@\"; exit \$?" > "${lm_stats_dir}/run.sh"; chmod +x "${lm_stats_dir}/run.sh"
-
- # 3. Submit jobs
- log "LM collect-stats started... log: '${_logdir}/stats.*.log'"
- # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
- # but it's used only for deciding the sample ids.
- # shellcheck disable=SC2086
- ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
- ${python} -m funasr.bin.lm_train \
- --collect_stats true \
- --use_preprocessor true \
- --bpemodel "${bpemodel}" \
- --token_type "${lm_token_type}"\
- --token_list "${lm_token_list}" \
- --non_linguistic_symbols "${nlsyms_txt}" \
- --cleaner "${cleaner}" \
- --g2p "${g2p}" \
- --train_data_path_and_name_and_type "${data_feats}/lm_train.txt,text,text" \
- --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
- --train_shape_file "${_logdir}/train.JOB.scp" \
- --valid_shape_file "${_logdir}/dev.JOB.scp" \
- --output_dir "${_logdir}/stats.JOB" \
- ${_opts} ${lm_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
-
- # 4. Aggregate shape files
- _opts=
- for i in $(seq "${_nj}"); do
- _opts+="--input_dir ${_logdir}/stats.${i} "
- done
- # shellcheck disable=SC2086
- ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${lm_stats_dir}"
-
- # Append the num-tokens at the last dimensions. This is used for batch-bins count
- <"${lm_stats_dir}/train/text_shape" \
- awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
- >"${lm_stats_dir}/train/text_shape.${lm_token_type}"
-
- <"${lm_stats_dir}/valid/text_shape" \
- awk -v N="$(<${lm_token_list} wc -l)" '{ print $0 "," N }' \
- >"${lm_stats_dir}/valid/text_shape.${lm_token_type}"
- fi
-
-
- if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
- log "Stage 8: LM Training: train_set=${data_feats}/lm_train.txt, dev_set=${lm_dev_text}"
-
- _opts=
- if [ -n "${lm_config}" ]; then
- # To generate the config file: e.g.
- # % python3 -m espnet2.bin.lm_train --print_config --optim adam
- _opts+="--config ${lm_config} "
- fi
-
- if [ "${num_splits_lm}" -gt 1 ]; then
- # If you met a memory error when parsing text files, this option may help you.
- # The corpus is split into subsets and each subset is used for training one by one in order,
- # so the memory footprint can be limited to the memory required for each dataset.
-
- _split_dir="${lm_stats_dir}/splits${num_splits_lm}"
- if [ ! -f "${_split_dir}/.done" ]; then
- rm -f "${_split_dir}/.done"
- ${python} -m espnet2.bin.split_scps \
- --scps "${data_feats}/lm_train.txt" "${lm_stats_dir}/train/text_shape.${lm_token_type}" \
- --num_splits "${num_splits_lm}" \
- --output_dir "${_split_dir}"
- touch "${_split_dir}/.done"
- else
- log "${_split_dir}/.done exists. Spliting is skipped"
- fi
-
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/lm_train.txt,text,text "
- _opts+="--train_shape_file ${_split_dir}/text_shape.${lm_token_type} "
- _opts+="--multiple_iterator true "
-
- else
- _opts+="--train_data_path_and_name_and_type ${data_feats}/lm_train.txt,text,text "
- _opts+="--train_shape_file ${lm_stats_dir}/train/text_shape.${lm_token_type} "
- fi
-
- # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
-
- log "Generate '${lm_exp}/run.sh'. You can resume the process from stage 8 using this script"
- mkdir -p "${lm_exp}"; echo "${run_args} --stage 8 \"\$@\"; exit \$?" > "${lm_exp}/run.sh"; chmod +x "${lm_exp}/run.sh"
-
- log "LM training started... log: '${lm_exp}/train.log'"
- if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
- # SGE can't include "/" in a job name
- jobname="$(basename ${lm_exp})"
- else
- jobname="${lm_exp}/train.log"
- fi
-
- mkdir -p ${lm_exp}
- mkdir -p ${lm_exp}/log
- INIT_FILE=${lm_exp}/ddp_init
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $ngpu; ++i)); do
- {
- # i=0
- rank=$i
- local_rank=$i
- gpu_id=$(echo $device | cut -d',' -f$[$i+1])
- lm_train.py \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --bpemodel ${bpemodel} \
- --token_type ${token_type} \
- --token_list ${token_list} \
- --non_linguistic_symbols ${nlsyms_txt} \
- --cleaner ${cleaner} \
- --g2p ${g2p} \
- --valid_data_path_and_name_and_type "${lm_dev_text},text,text" \
- --valid_shape_file "${lm_stats_dir}/valid/text_shape.${lm_token_type}" \
- --resume true \
- --output_dir ${lm_exp} \
- --config $lm_config \
- --ngpu $ngpu \
- --num_worker_count 1 \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $ngpu \
- --dist_rank $rank \
- --local_rank $local_rank \
- ${_opts} 1> ${lm_exp}/log/train.log.$i 2>&1
- } &
- done
- wait
-
- fi
-
-
- if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
- log "Stage 9: Calc perplexity: ${lm_test_text}"
- _opts=
- # TODO(kamo): Parallelize?
- log "Perplexity calculation started... log: '${lm_exp}/perplexity_test/lm_calc_perplexity.log'"
- # shellcheck disable=SC2086
- CUDA_VISIBLE_DEVICES=${device}\
- ${cuda_cmd} --gpu "${ngpu}" "${lm_exp}"/perplexity_test/lm_calc_perplexity.log \
- ${python} -m funasr.bin.lm_calc_perplexity \
- --ngpu "${ngpu}" \
- --data_path_and_name_and_type "${lm_test_text},text,text" \
- --train_config "${lm_exp}"/config.yaml \
- --model_file "${lm_exp}/${inference_lm}" \
- --output_dir "${lm_exp}/perplexity_test" \
- ${_opts}
- log "PPL: ${lm_test_text}: $(cat ${lm_exp}/perplexity_test/ppl)"
-
- fi
-
- else
- log "Stage 7-9: Skip lm-related stages: use_lm=${use_lm}"
- fi
-
-
- if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
- _asr_train_dir="${data_feats}/${train_set}"
- _asr_valid_dir="${data_feats}/${valid_set}"
- log "Stage 10: ASR collect stats: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
-
- _opts=
- if [ -n "${asr_config}" ]; then
- # To generate the config file: e.g.
- # % python3 -m espnet2.bin.asr_train --print_config --optim adam
- _opts+="--config ${asr_config} "
- fi
-
- _feats_type="$(<${_asr_train_dir}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- # "sound" supports "wav", "flac", etc.
- _type=sound
- fi
- _opts+="--frontend_conf fs=${fs} "
- else
- _scp=feats.scp
- _type=kaldi_ark
- _input_size="$(<${_asr_train_dir}/feats_dim)"
- _opts+="--input_size=${_input_size} "
- fi
-
- # 1. Split the key file
- _logdir="${asr_stats_dir}/logdir"
- mkdir -p "${_logdir}"
-
- # Get the minimum number among ${nj} and the number lines of input files
- _nj=$(min "${nj}" "$(<${_asr_train_dir}/${_scp} wc -l)" "$(<${_asr_valid_dir}/${_scp} wc -l)")
-
- key_file="${_asr_train_dir}/${_scp}"
- split_scps=""
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/train.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- key_file="${_asr_valid_dir}/${_scp}"
- split_scps=""
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/valid.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Generate run.sh
- log "Generate '${asr_stats_dir}/run.sh'. You can resume the process from stage 9 using this script"
- mkdir -p "${asr_stats_dir}"; echo "${run_args} --stage 9 \"\$@\"; exit \$?" > "${asr_stats_dir}/run.sh"; chmod +x "${asr_stats_dir}/run.sh"
-
- # 3. Submit jobs
- log "ASR collect-stats started... log: '${_logdir}/stats.*.log'"
-
- # NOTE: --*_shape_file doesn't require length information if --batch_type=unsorted,
- # but it's used only for deciding the sample ids.
-
- # shellcheck disable=SC2086
- ${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
- ${python} -m funasr.bin.asr_train \
- --collect_stats true \
- --mc true \
- --use_preprocessor true \
- --bpemodel "${bpemodel}" \
- --token_type "${token_type}" \
- --token_list "${token_list}" \
- --split_with_space false \
- --non_linguistic_symbols "${nlsyms_txt}" \
- --cleaner "${cleaner}" \
- --g2p "${g2p}" \
- --train_data_path_and_name_and_type "${_asr_train_dir}/${_scp},speech,${_type}" \
- --train_data_path_and_name_and_type "${_asr_train_dir}/text,text,text" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
- --train_shape_file "${_logdir}/train.JOB.scp" \
- --valid_shape_file "${_logdir}/valid.JOB.scp" \
- --output_dir "${_logdir}/stats.JOB" \
- ${_opts} ${asr_args} || { cat "${_logdir}"/stats.1.log; exit 1; }
-
- # 4. Aggregate shape files
- _opts=
- for i in $(seq "${_nj}"); do
- _opts+="--input_dir ${_logdir}/stats.${i} "
- done
- # shellcheck disable=SC2086
- ${python} -m funasr.bin.aggregate_stats_dirs ${_opts} --output_dir "${asr_stats_dir}"
-
- # Append the num-tokens at the last dimensions. This is used for batch-bins count
- <"${asr_stats_dir}/train/text_shape" \
- awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
- >"${asr_stats_dir}/train/text_shape.${token_type}"
-
- <"${asr_stats_dir}/valid/text_shape" \
- awk -v N="$(<${token_list} wc -l)" '{ print $0 "," N }' \
- >"${asr_stats_dir}/valid/text_shape.${token_type}"
- fi
-
-
- if [ ${stage} -le 11 ] && [ ${stop_stage} -ge 11 ]; then
- _asr_train_dir="${data_feats}/${train_set}"
- _asr_valid_dir="${data_feats}/${valid_set}"
- log "Stage 11: ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
-
- _opts=
- if [ -n "${asr_config}" ]; then
- # To generate the config file: e.g.
- # % python3 -m espnet2.bin.asr_train --print_config --optim adam
- _opts+="--config ${asr_config} "
- fi
-
- _feats_type="$(<${_asr_train_dir}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- # "sound" supports "wav", "flac", etc.
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- _opts+="--frontend_conf fs=${fs} "
- else
- _scp=feats.scp
- _type=kaldi_ark
- _input_size="$(<${_asr_train_dir}/feats_dim)"
- _opts+="--input_size=${_input_size} "
-
- fi
- if [ "${feats_normalize}" = global_mvn ]; then
- # Default normalization is utterance_mvn and changes to global_mvn
- _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
- fi
-
- if [ "${num_splits_asr}" -gt 1 ]; then
- # If you met a memory error when parsing text files, this option may help you.
- # The corpus is split into subsets and each subset is used for training one by one in order,
- # so the memory footprint can be limited to the memory required for each dataset.
-
- _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
- if [ ! -f "${_split_dir}/.done" ]; then
- rm -f "${_split_dir}/.done"
- ${python} -m espnet2.bin.split_scps \
- --scps \
- "${_asr_train_dir}/${_scp}" \
- "${_asr_train_dir}/text" \
- "${asr_stats_dir}/train/speech_shape" \
- "${asr_stats_dir}/train/text_shape.${token_type}" \
- --num_splits "${num_splits_asr}" \
- --output_dir "${_split_dir}"
- touch "${_split_dir}/.done"
- else
- log "${_split_dir}/.done exists. Spliting is skipped"
- fi
-
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
- _opts+="--train_shape_file ${_split_dir}/speech_shape "
- _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
- _opts+="--multiple_iterator true "
-
- else
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
- _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
- _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
- fi
-
- # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
- # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
-
- # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
- log "ASR training started... log: '${asr_exp}/log/train.log'"
- # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
- # # SGE can't include "/" in a job name
- # jobname="$(basename ${asr_exp})"
- # else
- # jobname="${asr_exp}/train.log"
- # fi
-
- mkdir -p ${asr_exp}
- mkdir -p ${asr_exp}/log
- INIT_FILE=${asr_exp}/ddp_init
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $ngpu; ++i)); do
- {
- # i=0
- rank=$i
- local_rank=$i
- gpu_id=$(echo $device | cut -d',' -f$[$i+1])
- asr_train.py \
- --mc true \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --bpemodel ${bpemodel} \
- --token_type ${token_type} \
- --token_list ${token_list} \
- --split_with_space false \
- --non_linguistic_symbols ${nlsyms_txt} \
- --cleaner ${cleaner} \
- --g2p ${g2p} \
- --valid_data_path_and_name_and_type ${_asr_valid_dir}/${_scp},speech,${_type} \
- --valid_data_path_and_name_and_type ${_asr_valid_dir}/text,text,text \
- --valid_shape_file ${asr_stats_dir}/valid/speech_shape \
- --valid_shape_file ${asr_stats_dir}/valid/text_shape.${token_type} \
- --resume true \
- --output_dir ${asr_exp} \
- --config $asr_config \
- --ngpu $ngpu \
- --num_worker_count 1 \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $ngpu \
- --dist_rank $rank \
- --local_rank $local_rank \
- ${_opts} 1> ${asr_exp}/log/train.log.$i 2>&1
- } &
- done
- wait
-
- fi
-
- if [ ${stage} -le 12 ] && [ ${stop_stage} -ge 12 ]; then
- _asr_train_dir="${data_feats}/${train_set}"
- _asr_valid_dir="${data_feats}/${valid_set}"
- log "Stage 12: SA-ASR Training: train_set=${_asr_train_dir}, valid_set=${_asr_valid_dir}"
-
- _opts=
- if [ -n "${sa_asr_config}" ]; then
- # To generate the config file: e.g.
- # % python3 -m espnet2.bin.asr_train --print_config --optim adam
- _opts+="--config ${sa_asr_config} "
- fi
-
- _feats_type="$(<${_asr_train_dir}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- # "sound" supports "wav", "flac", etc.
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- _opts+="--frontend_conf fs=${fs} "
- else
- _scp=feats.scp
- _type=kaldi_ark
- _input_size="$(<${_asr_train_dir}/feats_dim)"
- _opts+="--input_size=${_input_size} "
-
- fi
- if [ "${feats_normalize}" = global_mvn ]; then
- # Default normalization is utterance_mvn and changes to global_mvn
- _opts+="--normalize=global_mvn --normalize_conf stats_file=${asr_stats_dir}/train/feats_stats.npz "
- fi
-
- if [ "${num_splits_asr}" -gt 1 ]; then
- # If you met a memory error when parsing text files, this option may help you.
- # The corpus is split into subsets and each subset is used for training one by one in order,
- # so the memory footprint can be limited to the memory required for each dataset.
-
- _split_dir="${asr_stats_dir}/splits${num_splits_asr}"
- if [ ! -f "${_split_dir}/.done" ]; then
- rm -f "${_split_dir}/.done"
- ${python} -m espnet2.bin.split_scps \
- --scps \
- "${_asr_train_dir}/${_scp}" \
- "${_asr_train_dir}/text" \
- "${asr_stats_dir}/train/speech_shape" \
- "${asr_stats_dir}/train/text_shape.${token_type}" \
- --num_splits "${num_splits_asr}" \
- --output_dir "${_split_dir}"
- touch "${_split_dir}/.done"
- else
- log "${_split_dir}/.done exists. Spliting is skipped"
- fi
-
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/${_scp},speech,${_type} "
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/text,text,text "
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/text_id_train,text_id,text_int "
- _opts+="--train_data_path_and_name_and_type ${_split_dir}/oracle_profile_padding.scp,profile,npy "
- _opts+="--train_shape_file ${_split_dir}/speech_shape "
- _opts+="--train_shape_file ${_split_dir}/text_shape.${token_type} "
- _opts+="--multiple_iterator true "
-
- else
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/${_scp},speech,${_type} "
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text,text,text "
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/oracle_profile_padding.scp,profile,npy "
- _opts+="--train_data_path_and_name_and_type ${_asr_train_dir}/text_id_train,text_id,text_int "
- _opts+="--train_shape_file ${asr_stats_dir}/train/speech_shape "
- _opts+="--train_shape_file ${asr_stats_dir}/train/text_shape.${token_type} "
- fi
-
- # log "Generate '${asr_exp}/run.sh'. You can resume the process from stage 10 using this script"
- # mkdir -p "${asr_exp}"; echo "${run_args} --stage 10 \"\$@\"; exit \$?" > "${asr_exp}/run.sh"; chmod +x "${asr_exp}/run.sh"
-
- # NOTE(kamo): --fold_length is used only if --batch_type=folded and it's ignored in the other case
- log "SA-ASR training started... log: '${sa_asr_exp}/log/train.log'"
- # if echo "${cuda_cmd}" | grep -e queue.pl -e queue-freegpu.pl &> /dev/null; then
- # # SGE can't include "/" in a job name
- # jobname="$(basename ${asr_exp})"
- # else
- # jobname="${asr_exp}/train.log"
- # fi
-
- mkdir -p ${sa_asr_exp}
- mkdir -p ${sa_asr_exp}/log
- INIT_FILE=${sa_asr_exp}/ddp_init
-
- if [ ! -f "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
- # download xvector extractor model file
- python local/download_xvector_model.py exp
- log "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
- fi
-
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $ngpu; ++i)); do
- {
- # i=0
- rank=$i
- local_rank=$i
- gpu_id=$(echo $device | cut -d',' -f$[$i+1])
- sa_asr_train.py \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --unused_parameters true \
- --bpemodel ${bpemodel} \
- --token_type ${token_type} \
- --token_list ${token_list} \
- --max_spk_num 4 \
- --split_with_space false \
- --non_linguistic_symbols ${nlsyms_txt} \
- --cleaner ${cleaner} \
- --g2p ${g2p} \
- --allow_variable_data_keys true \
- --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
- --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
- --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
- --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
- --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/${_scp},speech,${_type}" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/text,text,text" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/oracle_profile_nopadding.scp,profile,npy" \
- --valid_data_path_and_name_and_type "${_asr_valid_dir}/text_id_train,text_id,text_int" \
- --valid_shape_file "${asr_stats_dir}/valid/speech_shape" \
- --valid_shape_file "${asr_stats_dir}/valid/text_shape.${token_type}" \
- --resume true \
- --output_dir ${sa_asr_exp} \
- --config $sa_asr_config \
- --ngpu $ngpu \
- --num_worker_count 1 \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $ngpu \
- --dist_rank $rank \
- --local_rank $local_rank \
- ${_opts} 1> ${sa_asr_exp}/log/train.log.$i 2>&1
- } &
- done
- wait
-
- fi
-
-else
- log "Skip the training stages"
-fi
-
-if ${infer_with_pretrained_model}; then
- log "Use ${download_sa_asr_model} for decoding and evaluation"
- sa_asr_exp="${expdir}/${download_sa_asr_model}"
- mkdir -p "${sa_asr_exp}"
-
-
- python local/download_pretrained_model_from_modelscope.py $download_sa_asr_model ${expdir}
- inference_sa_asr_model="model.pb"
- inference_config=${sa_asr_exp}/decoding.yaml
-fi
-
-if ! "${skip_eval}"; then
- if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
- log "Stage 13: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
-
- if ${gpu_inference}; then
- _cmd="${cuda_cmd}"
- inference_nj=$[${ngpu}*${njob_infer}]
- _ngpu=1
-
- else
- _cmd="${decode_cmd}"
- inference_nj=$inference_nj
- _ngpu=0
- fi
-
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- if "${use_lm}"; then
- if "${use_word_lm}"; then
- _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
- _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
- else
- _opts+="--lm_train_config ${lm_exp}/config.yaml "
- _opts+="--lm_file ${lm_exp}/${inference_lm} "
- fi
- fi
-
- # 2. Generate run.sh
- log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh'. You can resume the process from stage 15 using this script"
- mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.oracle"; echo "${run_args} --stage 15 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.oracle/run.sh"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
- _logdir="${_dir}/logdir"
- mkdir -p "${_logdir}"
-
- _feats_type="$(<${_data}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- else
- _scp=feats.scp
- _type=kaldi_ark
- fi
-
- # 1. Split the key file
- key_file=${_data}/${_scp}
- split_scps=""
- _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Submit decoding jobs
- log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
- # shellcheck disable=SC2086
- ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --mc True \
- --nbest 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob_infer} \
- --gpuid_list ${device} \
- --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
- --data_path_and_name_and_type "${_data}/oracle_profile_nopadding.scp,profile,npy" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --allow_variable_data_keys true \
- --asr_train_config "${sa_asr_exp}"/config.yaml \
- --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode sa_asr \
- ${_opts}
-
-
- # 3. Concatenates the output files from each jobs
- for f in token token_int score text text_id; do
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | LC_ALL=C sort -k1 >"${_dir}/${f}"
- done
- done
- fi
-
- if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
- log "Stage 14: Scoring SA-ASR (oracle profile)"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
-
- sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
- sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
-
- python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
- python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
-
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
-
- python local/process_text_spk_merge.py ${_dir}
- python local/process_text_spk_merge.py ${_data}
-
- python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
- tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
- cat ${_dir}/text.cpcer.txt
-
- done
-
- fi
-
- if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
- log "Stage 15: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
-
- if ${gpu_inference}; then
- _cmd="${cuda_cmd}"
- inference_nj=$[${ngpu}*${njob_infer}]
- _ngpu=1
-
- else
- _cmd="${decode_cmd}"
- inference_nj=$inference_nj
- _ngpu=0
- fi
-
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- if "${use_lm}"; then
- if "${use_word_lm}"; then
- _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
- _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
- else
- _opts+="--lm_train_config ${lm_exp}/config.yaml "
- _opts+="--lm_file ${lm_exp}/${inference_lm} "
- fi
- fi
-
- # 2. Generate run.sh
- log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
- mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
- _logdir="${_dir}/logdir"
- mkdir -p "${_logdir}"
-
- _feats_type="$(<${_data}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- else
- _scp=feats.scp
- _type=kaldi_ark
- fi
-
- # 1. Split the key file
- key_file=${_data}/${_scp}
- split_scps=""
- _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Submit decoding jobs
- log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
- # shellcheck disable=SC2086
- ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --mc True \
- --nbest 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob_infer} \
- --gpuid_list ${device} \
- --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
- --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --allow_variable_data_keys true \
- --asr_train_config "${sa_asr_exp}"/config.yaml \
- --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode sa_asr \
- ${_opts}
-
- # 3. Concatenates the output files from each jobs
- for f in token token_int score text text_id; do
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | LC_ALL=C sort -k1 >"${_dir}/${f}"
- done
- done
- fi
-
- if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
- log "Stage 16: Scoring SA-ASR (cluster profile)"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
-
- sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
- sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
-
- python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
- python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
-
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
-
- python local/process_text_spk_merge.py ${_dir}
- python local/process_text_spk_merge.py ${_data}
-
- python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
- tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
- cat ${_dir}/text.cpcer.txt
-
- done
-
- fi
-
-else
- log "Skip the evaluation stages"
-fi
-
-
-log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
deleted file mode 100755
index a23215c..0000000
--- a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
+++ /dev/null
@@ -1,591 +0,0 @@
-#!/usr/bin/env bash
-
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-log() {
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-min() {
- local a b
- a=$1
- for b in "$@"; do
- if [ "${b}" -le "${a}" ]; then
- a="${b}"
- fi
- done
- echo "${a}"
-}
-SECONDS=0
-
-# General configuration
-stage=1 # Processes starts from the specified stage.
-stop_stage=10000 # Processes is stopped at the specified stage.
-skip_data_prep=false # Skip data preparation stages.
-skip_train=false # Skip training stages.
-skip_eval=false # Skip decoding and evaluation stages.
-skip_upload=true # Skip packing and uploading stages.
-ngpu=1 # The number of gpus ("0" uses cpu, otherwise use gpu).
-num_nodes=1 # The number of nodes.
-nj=16 # The number of parallel jobs.
-inference_nj=16 # The number of parallel jobs in decoding.
-gpu_inference=false # Whether to perform gpu decoding.
-njob_infer=4
-dumpdir=dump2 # Directory to dump features.
-expdir=exp # Directory to save experiments.
-python=python3 # Specify python to execute espnet commands.
-device=0
-
-# Data preparation related
-local_data_opts= # The options given to local/data.sh.
-
-# Speed perturbation related
-speed_perturb_factors= # perturbation factors, e.g. "0.9 1.0 1.1" (separated by space).
-
-# Feature extraction related
-feats_type=raw # Feature type (raw or fbank_pitch).
-audio_format=flac # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw).
-fs=16000 # Sampling rate.
-min_wav_duration=0.1 # Minimum duration in second.
-max_wav_duration=20 # Maximum duration in second.
-
-# Tokenization related
-token_type=bpe # Tokenization type (char or bpe).
-nbpe=30 # The number of BPE vocabulary.
-bpemode=unigram # Mode of BPE (unigram or bpe).
-oov="<unk>" # Out of vocabulary symbol.
-blank="<blank>" # CTC blank symbol
-sos_eos="<sos/eos>" # sos and eos symbole
-bpe_input_sentence_size=100000000 # Size of input sentence for BPE.
-bpe_nlsyms= # non-linguistic symbols list, separated by a comma, for BPE
-bpe_char_cover=1.0 # character coverage when modeling BPE
-
-# Language model related
-use_lm=true # Use language model for ASR decoding.
-lm_tag= # Suffix to the result dir for language model training.
-lm_exp= # Specify the direcotry path for LM experiment.
- # If this option is specified, lm_tag is ignored.
-lm_stats_dir= # Specify the direcotry path for LM statistics.
-lm_config= # Config for language model training.
-lm_args= # Arguments for language model training, e.g., "--max_epoch 10".
- # Note that it will overwrite args in lm config.
-use_word_lm=false # Whether to use word language model.
-num_splits_lm=1 # Number of splitting for lm corpus.
-# shellcheck disable=SC2034
-word_vocab_size=10000 # Size of word vocabulary.
-
-# ASR model related
-asr_tag= # Suffix to the result dir for asr model training.
-asr_exp= # Specify the direcotry path for ASR experiment.
- # If this option is specified, asr_tag is ignored.
-sa_asr_exp=
-asr_stats_dir= # Specify the direcotry path for ASR statistics.
-asr_config= # Config for asr model training.
-sa_asr_config=
-asr_args= # Arguments for asr model training, e.g., "--max_epoch 10".
- # Note that it will overwrite args in asr config.
-feats_normalize=global_mvn # Normalizaton layer type.
-num_splits_asr=1 # Number of splitting for lm corpus.
-
-# Decoding related
-inference_tag= # Suffix to the result dir for decoding.
-inference_config= # Config for decoding.
-inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
- # Note that it will overwrite args in inference config.
-sa_asr_inference_tag=
-sa_asr_inference_args=
-
-inference_lm=valid.loss.ave.pb # Language modle path for decoding.
-inference_asr_model=valid.acc.ave.pb # ASR model path for decoding.
- # e.g.
- # inference_asr_model=train.loss.best.pth
- # inference_asr_model=3epoch.pth
- # inference_asr_model=valid.acc.best.pth
- # inference_asr_model=valid.loss.ave.pth
-inference_sa_asr_model=valid.acc_spk.ave.pb
-download_model= # Download a model from Model Zoo and use it for decoding.
-
-# [Task dependent] Set the datadir name created by local/data.sh
-train_set= # Name of training set.
-valid_set= # Name of validation set used for monitoring/tuning network training.
-test_sets= # Names of test sets. Multiple items (e.g., both dev and eval sets) can be specified.
-bpe_train_text= # Text file path of bpe training set.
-lm_train_text= # Text file path of language model training set.
-lm_dev_text= # Text file path of language model development set.
-lm_test_text= # Text file path of language model evaluation set.
-nlsyms_txt=none # Non-linguistic symbol list if existing.
-cleaner=none # Text cleaner.
-g2p=none # g2p method (needed if token_type=phn).
-lang=zh # The language type of corpus.
-score_opts= # The options given to sclite scoring
-local_score_opts= # The options given to local/score.sh.
-
-help_message=$(cat << EOF
-Usage: $0 --train-set "<train_set_name>" --valid-set "<valid_set_name>" --test_sets "<test_set_names>"
-
-Options:
- # General configuration
- --stage # Processes starts from the specified stage (default="${stage}").
- --stop_stage # Processes is stopped at the specified stage (default="${stop_stage}").
- --skip_data_prep # Skip data preparation stages (default="${skip_data_prep}").
- --skip_train # Skip training stages (default="${skip_train}").
- --skip_eval # Skip decoding and evaluation stages (default="${skip_eval}").
- --skip_upload # Skip packing and uploading stages (default="${skip_upload}").
- --ngpu # The number of gpus ("0" uses cpu, otherwise use gpu, default="${ngpu}").
- --num_nodes # The number of nodes (default="${num_nodes}").
- --nj # The number of parallel jobs (default="${nj}").
- --inference_nj # The number of parallel jobs in decoding (default="${inference_nj}").
- --gpu_inference # Whether to perform gpu decoding (default="${gpu_inference}").
- --dumpdir # Directory to dump features (default="${dumpdir}").
- --expdir # Directory to save experiments (default="${expdir}").
- --python # Specify python to execute espnet commands (default="${python}").
- --device # Which GPUs are use for local training (defalut="${device}").
-
- # Data preparation related
- --local_data_opts # The options given to local/data.sh (default="${local_data_opts}").
-
- # Speed perturbation related
- --speed_perturb_factors # speed perturbation factors, e.g. "0.9 1.0 1.1" (separated by space, default="${speed_perturb_factors}").
-
- # Feature extraction related
- --feats_type # Feature type (raw, fbank_pitch or extracted, default="${feats_type}").
- --audio_format # Audio format: wav, flac, wav.ark, flac.ark (only in feats_type=raw, default="${audio_format}").
- --fs # Sampling rate (default="${fs}").
- --min_wav_duration # Minimum duration in second (default="${min_wav_duration}").
- --max_wav_duration # Maximum duration in second (default="${max_wav_duration}").
-
- # Tokenization related
- --token_type # Tokenization type (char or bpe, default="${token_type}").
- --nbpe # The number of BPE vocabulary (default="${nbpe}").
- --bpemode # Mode of BPE (unigram or bpe, default="${bpemode}").
- --oov # Out of vocabulary symbol (default="${oov}").
- --blank # CTC blank symbol (default="${blank}").
- --sos_eos # sos and eos symbole (default="${sos_eos}").
- --bpe_input_sentence_size # Size of input sentence for BPE (default="${bpe_input_sentence_size}").
- --bpe_nlsyms # Non-linguistic symbol list for sentencepiece, separated by a comma. (default="${bpe_nlsyms}").
- --bpe_char_cover # Character coverage when modeling BPE (default="${bpe_char_cover}").
-
- # Language model related
- --lm_tag # Suffix to the result dir for language model training (default="${lm_tag}").
- --lm_exp # Specify the direcotry path for LM experiment.
- # If this option is specified, lm_tag is ignored (default="${lm_exp}").
- --lm_stats_dir # Specify the direcotry path for LM statistics (default="${lm_stats_dir}").
- --lm_config # Config for language model training (default="${lm_config}").
- --lm_args # Arguments for language model training (default="${lm_args}").
- # e.g., --lm_args "--max_epoch 10"
- # Note that it will overwrite args in lm config.
- --use_word_lm # Whether to use word language model (default="${use_word_lm}").
- --word_vocab_size # Size of word vocabulary (default="${word_vocab_size}").
- --num_splits_lm # Number of splitting for lm corpus (default="${num_splits_lm}").
-
- # ASR model related
- --asr_tag # Suffix to the result dir for asr model training (default="${asr_tag}").
- --asr_exp # Specify the direcotry path for ASR experiment.
- # If this option is specified, asr_tag is ignored (default="${asr_exp}").
- --asr_stats_dir # Specify the direcotry path for ASR statistics (default="${asr_stats_dir}").
- --asr_config # Config for asr model training (default="${asr_config}").
- --asr_args # Arguments for asr model training (default="${asr_args}").
- # e.g., --asr_args "--max_epoch 10"
- # Note that it will overwrite args in asr config.
- --feats_normalize # Normalizaton layer type (default="${feats_normalize}").
- --num_splits_asr # Number of splitting for lm corpus (default="${num_splits_asr}").
-
- # Decoding related
- --inference_tag # Suffix to the result dir for decoding (default="${inference_tag}").
- --inference_config # Config for decoding (default="${inference_config}").
- --inference_args # Arguments for decoding (default="${inference_args}").
- # e.g., --inference_args "--lm_weight 0.1"
- # Note that it will overwrite args in inference config.
- --inference_lm # Language modle path for decoding (default="${inference_lm}").
- --inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
- --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
-
- # [Task dependent] Set the datadir name created by local/data.sh
- --train_set # Name of training set (required).
- --valid_set # Name of validation set used for monitoring/tuning network training (required).
- --test_sets # Names of test sets.
- # Multiple items (e.g., both dev and eval sets) can be specified (required).
- --bpe_train_text # Text file path of bpe training set.
- --lm_train_text # Text file path of language model training set.
- --lm_dev_text # Text file path of language model development set (default="${lm_dev_text}").
- --lm_test_text # Text file path of language model evaluation set (default="${lm_test_text}").
- --nlsyms_txt # Non-linguistic symbol list if existing (default="${nlsyms_txt}").
- --cleaner # Text cleaner (default="${cleaner}").
- --g2p # g2p method (default="${g2p}").
- --lang # The language type of corpus (default=${lang}).
- --score_opts # The options given to sclite scoring (default="{score_opts}").
- --local_score_opts # The options given to local/score.sh (default="{local_score_opts}").
-EOF
-)
-
-log "$0 $*"
-# Save command line args for logging (they will be lost after utils/parse_options.sh)
-run_args=$(python -m funasr.utils.cli_utils $0 "$@")
-. utils/parse_options.sh
-
-if [ $# -ne 0 ]; then
- log "${help_message}"
- log "Error: No positional arguments are required."
- exit 2
-fi
-
-. ./path.sh
-
-
-# Check required arguments
-[ -z "${train_set}" ] && { log "${help_message}"; log "Error: --train_set is required"; exit 2; };
-[ -z "${valid_set}" ] && { log "${help_message}"; log "Error: --valid_set is required"; exit 2; };
-[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
-
-# Check feature type
-if [ "${feats_type}" = raw ]; then
- data_feats=${dumpdir}/raw
-elif [ "${feats_type}" = fbank_pitch ]; then
- data_feats=${dumpdir}/fbank_pitch
-elif [ "${feats_type}" = fbank ]; then
- data_feats=${dumpdir}/fbank
-elif [ "${feats_type}" == extracted ]; then
- data_feats=${dumpdir}/extracted
-else
- log "${help_message}"
- log "Error: not supported: --feats_type ${feats_type}"
- exit 2
-fi
-
-# Use the same text as ASR for bpe training if not specified.
-[ -z "${bpe_train_text}" ] && bpe_train_text="${data_feats}/${train_set}/text"
-# Use the same text as ASR for lm training if not specified.
-[ -z "${lm_train_text}" ] && lm_train_text="${data_feats}/${train_set}/text"
-# Use the same text as ASR for lm training if not specified.
-[ -z "${lm_dev_text}" ] && lm_dev_text="${data_feats}/${valid_set}/text"
-# Use the text of the 1st evaldir if lm_test is not specified
-[ -z "${lm_test_text}" ] && lm_test_text="${data_feats}/${test_sets%% *}/text"
-
-# Check tokenization type
-if [ "${lang}" != noinfo ]; then
- token_listdir=data/${lang}_token_list
-else
- token_listdir=data/token_list
-fi
-bpedir="${token_listdir}/bpe_${bpemode}${nbpe}"
-bpeprefix="${bpedir}"/bpe
-bpemodel="${bpeprefix}".model
-bpetoken_list="${bpedir}"/tokens.txt
-chartoken_list="${token_listdir}"/char/tokens.txt
-# NOTE: keep for future development.
-# shellcheck disable=SC2034
-wordtoken_list="${token_listdir}"/word/tokens.txt
-
-if [ "${token_type}" = bpe ]; then
- token_list="${bpetoken_list}"
-elif [ "${token_type}" = char ]; then
- token_list="${chartoken_list}"
- bpemodel=none
-elif [ "${token_type}" = word ]; then
- token_list="${wordtoken_list}"
- bpemodel=none
-else
- log "Error: not supported --token_type '${token_type}'"
- exit 2
-fi
-if ${use_word_lm}; then
- log "Error: Word LM is not supported yet"
- exit 2
-
- lm_token_list="${wordtoken_list}"
- lm_token_type=word
-else
- lm_token_list="${token_list}"
- lm_token_type="${token_type}"
-fi
-
-
-# Set tag for naming of model directory
-if [ -z "${asr_tag}" ]; then
- if [ -n "${asr_config}" ]; then
- asr_tag="$(basename "${asr_config}" .yaml)_${feats_type}"
- else
- asr_tag="train_${feats_type}"
- fi
- if [ "${lang}" != noinfo ]; then
- asr_tag+="_${lang}_${token_type}"
- else
- asr_tag+="_${token_type}"
- fi
- if [ "${token_type}" = bpe ]; then
- asr_tag+="${nbpe}"
- fi
- # Add overwritten arg's info
- if [ -n "${asr_args}" ]; then
- asr_tag+="$(echo "${asr_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
- fi
- if [ -n "${speed_perturb_factors}" ]; then
- asr_tag+="_sp"
- fi
-fi
-if [ -z "${lm_tag}" ]; then
- if [ -n "${lm_config}" ]; then
- lm_tag="$(basename "${lm_config}" .yaml)"
- else
- lm_tag="train"
- fi
- if [ "${lang}" != noinfo ]; then
- lm_tag+="_${lang}_${lm_token_type}"
- else
- lm_tag+="_${lm_token_type}"
- fi
- if [ "${lm_token_type}" = bpe ]; then
- lm_tag+="${nbpe}"
- fi
- # Add overwritten arg's info
- if [ -n "${lm_args}" ]; then
- lm_tag+="$(echo "${lm_args}" | sed -e "s/--/\_/g" -e "s/[ |=/]//g")"
- fi
-fi
-
-# The directory used for collect-stats mode
-if [ -z "${asr_stats_dir}" ]; then
- if [ "${lang}" != noinfo ]; then
- asr_stats_dir="${expdir}/asr_stats_${feats_type}_${lang}_${token_type}"
- else
- asr_stats_dir="${expdir}/asr_stats_${feats_type}_${token_type}"
- fi
- if [ "${token_type}" = bpe ]; then
- asr_stats_dir+="${nbpe}"
- fi
- if [ -n "${speed_perturb_factors}" ]; then
- asr_stats_dir+="_sp"
- fi
-fi
-if [ -z "${lm_stats_dir}" ]; then
- if [ "${lang}" != noinfo ]; then
- lm_stats_dir="${expdir}/lm_stats_${lang}_${lm_token_type}"
- else
- lm_stats_dir="${expdir}/lm_stats_${lm_token_type}"
- fi
- if [ "${lm_token_type}" = bpe ]; then
- lm_stats_dir+="${nbpe}"
- fi
-fi
-# The directory used for training commands
-if [ -z "${asr_exp}" ]; then
- asr_exp="${expdir}/asr_${asr_tag}"
-fi
-if [ -z "${lm_exp}" ]; then
- lm_exp="${expdir}/lm_${lm_tag}"
-fi
-
-
-if [ -z "${inference_tag}" ]; then
- if [ -n "${inference_config}" ]; then
- inference_tag="$(basename "${inference_config}" .yaml)"
- else
- inference_tag=inference
- fi
- # Add overwritten arg's info
- if [ -n "${inference_args}" ]; then
- inference_tag+="$(echo "${inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
- fi
- if "${use_lm}"; then
- inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
- fi
- inference_tag+="_asr_model_$(echo "${inference_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
-fi
-
-if [ -z "${sa_asr_inference_tag}" ]; then
- if [ -n "${inference_config}" ]; then
- sa_asr_inference_tag="$(basename "${inference_config}" .yaml)"
- else
- sa_asr_inference_tag=sa_asr_inference
- fi
- # Add overwritten arg's info
- if [ -n "${sa_asr_inference_args}" ]; then
- sa_asr_inference_tag+="$(echo "${sa_asr_inference_args}" | sed -e "s/--/\_/g" -e "s/[ |=]//g")"
- fi
- if "${use_lm}"; then
- sa_asr_inference_tag+="_lm_$(basename "${lm_exp}")_$(echo "${inference_lm}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
- fi
- sa_asr_inference_tag+="_asr_model_$(echo "${inference_sa_asr_model}" | sed -e "s/\//_/g" -e "s/\.[^.]*$//g")"
-fi
-
-train_cmd="run.pl"
-cuda_cmd="run.pl"
-decode_cmd="run.pl"
-
-# ========================== Main stages start from here. ==========================
-
-if ! "${skip_data_prep}"; then
-
- if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- if [ "${feats_type}" = raw ]; then
- log "Stage 1: Format wav.scp: data/ -> ${data_feats}"
-
- # ====== Recreating "wav.scp" ======
- # Kaldi-wav.scp, which can describe the file path with unix-pipe, like "cat /some/path |",
- # shouldn't be used in training process.
- # "format_wav_scp.sh" dumps such pipe-style-wav to real audio file
- # and it can also change the audio-format and sampling rate.
- # If nothing is need, then format_wav_scp.sh does nothing:
- # i.e. the input file format and rate is same as the output.
-
- for dset in "${test_sets}" ; do
-
- _suf=""
-
- local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
-
- rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
- _opts=
- if [ -e data/"${dset}"/segments ]; then
- # "segments" is used for splitting wav files which are written in "wav".scp
- # into utterances. The file format of segments:
- # <segment_id> <record_id> <start_time> <end_time>
- # "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5"
- # Where the time is written in seconds.
- _opts+="--segments data/${dset}/segments "
- fi
- # shellcheck disable=SC2086
- local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
- --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
- "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
-
- echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"
- done
-
- else
- log "Error: not supported: --feats_type ${feats_type}"
- exit 2
- fi
- fi
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- log "Stage 2: Generate speaker profile by spectral-cluster"
- mkdir -p "profile_log"
- for dset in "${test_sets}"; do
- # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
- python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
- log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
- done
- fi
-
-else
- log "Skip the stages for data preparation"
-fi
-
-
-# ========================== Data preparation is done here. ==========================
-
-if ! "${skip_eval}"; then
-
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- log "Stage 3: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
-
- if ${gpu_inference}; then
- _cmd="${cuda_cmd}"
- inference_nj=$[${ngpu}*${njob_infer}]
- _ngpu=1
-
- else
- _cmd="${decode_cmd}"
- inference_nj=$njob_infer
- _ngpu=0
- fi
-
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- if "${use_lm}"; then
- if "${use_word_lm}"; then
- _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
- _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
- else
- _opts+="--lm_train_config ${lm_exp}/config.yaml "
- _opts+="--lm_file ${lm_exp}/${inference_lm} "
- fi
- fi
-
- # 2. Generate run.sh
- log "Generate '${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh'. You can resume the process from stage 17 using this script"
- mkdir -p "${sa_asr_exp}/${sa_asr_inference_tag}.cluster"; echo "${run_args} --stage 17 \"\$@\"; exit \$?" > "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"; chmod +x "${sa_asr_exp}/${sa_asr_inference_tag}.cluster/run.sh"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
- _logdir="${_dir}/logdir"
- mkdir -p "${_logdir}"
-
- _feats_type="$(<${_data}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- else
- _scp=feats.scp
- _type=kaldi_ark
- fi
-
- # 1. Split the key file
- key_file=${_data}/${_scp}
- split_scps=""
- _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Submit decoding jobs
- log "Decoding started... log: '${_logdir}/sa_asr_inference.*.log'"
- # shellcheck disable=SC2086
- ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --mc True \
- --nbest 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob_infer} \
- --gpuid_list ${device} \
- --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
- --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --allow_variable_data_keys true \
- --asr_train_config "${sa_asr_exp}"/config.yaml \
- --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode sa_asr \
- ${_opts}
-
- # 3. Concatenates the output files from each jobs
- for f in token token_int score text text_id; do
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | LC_ALL=C sort -k1 >"${_dir}/${f}"
- done
- done
- fi
-
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- log "Stage 4: Generate SA-ASR results (cluster profile)"
-
- for dset in ${test_sets}; do
- _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
-
- python local/process_text_spk_merge.py ${_dir}
- done
-
- fi
-
-else
- log "Skip the evaluation stages"
-fi
-
-
-log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
deleted file mode 100644
index 68520ae..0000000
--- a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-lm: transformer
-lm_conf:
- pos_enc: null
- embed_unit: 128
- att_unit: 512
- head: 8
- unit: 2048
- layer: 16
- dropout_rate: 0.1
-
-# optimization related
-grad_clip: 5.0
-batch_type: numel
-batch_bins: 500000 # 4gpus * 500000
-accum_grad: 1
-max_epoch: 15 # 15epoch is enougth
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 25000
-
-best_model_criterion:
-- - valid
- - loss
- - min
-keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
deleted file mode 100755
index dfc2b78..0000000
--- a/egs/alimeeting/sa-asr/path.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-export FUNASR_DIR=$PWD/../../..
-
-# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:./utils:$PATH
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/run.sh b/egs/alimeeting/sa-asr/run.sh
deleted file mode 100755
index c74df56..0000000
--- a/egs/alimeeting/sa-asr/run.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-ngpu=4
-device="0,1,2,3"
-
-stage=1
-stop_stage=16
-
-
-train_set=Train_Ali_far
-valid_set=Eval_Ali_far
-test_sets="Test_Ali_far"
-asr_config=conf/train_asr_conformer.yaml
-sa_asr_config=conf/train_sa_asr_conformer.yaml
-inference_config=conf/decode_asr_rnn.yaml
-infer_with_pretrained_model=false
-download_sa_asr_model="damo/speech_saasr_asr-zh-cn-16k-alimeeting"
-
-lm_config=conf/train_lm_transformer.yaml
-use_lm=false
-use_wordlm=false
-./asr_local.sh \
- --device ${device} \
- --ngpu ${ngpu} \
- --stage ${stage} \
- --stop_stage ${stop_stage} \
- --gpu_inference true \
- --njob_infer 4 \
- --infer_with_pretrained_model ${infer_with_pretrained_model} \
- --download_sa_asr_model $download_sa_asr_model \
- --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
- --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
- --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
- --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
- --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
- --lang zh \
- --audio_format wav \
- --feats_type raw \
- --token_type char \
- --use_lm ${use_lm} \
- --use_word_lm ${use_wordlm} \
- --lm_config "${lm_config}" \
- --asr_config "${asr_config}" \
- --sa_asr_config "${sa_asr_config}" \
- --inference_config "${inference_config}" \
- --train_set "${train_set}" \
- --valid_set "${valid_set}" \
- --test_sets "${test_sets}" \
- --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
deleted file mode 100755
index 1967864..0000000
--- a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
+++ /dev/null
@@ -1,50 +0,0 @@
-#!/usr/bin/env bash
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-ngpu=4
-device="0,1,2,3"
-
-stage=1
-stop_stage=4
-
-
-train_set=Train_Ali_far
-valid_set=Eval_Ali_far
-test_sets="Test_2023_Ali_far"
-asr_config=conf/train_asr_conformer.yaml
-sa_asr_config=conf/train_sa_asr_conformer.yaml
-inference_config=conf/decode_asr_rnn.yaml
-
-lm_config=conf/train_lm_transformer.yaml
-use_lm=false
-use_wordlm=false
-./asr_local_m2met_2023_infer.sh \
- --device ${device} \
- --ngpu ${ngpu} \
- --stage ${stage} \
- --stop_stage ${stop_stage} \
- --gpu_inference true \
- --njob_infer 4 \
- --asr_exp exp/asr_train_multispeaker_conformer_raw_zh_char_data_alimeeting \
- --sa_asr_exp exp/sa_asr_train_conformer_raw_zh_char_data_alimeeting \
- --asr_stats_dir exp/asr_stats_multispeaker_conformer_raw_zh_char_data_alimeeting \
- --lm_exp exp/lm_train_multispeaker_transformer_zh_char_data_alimeeting \
- --lm_stats_dir exp/lm_stats_multispeaker_zh_char_data_alimeeting \
- --lang zh \
- --audio_format wav \
- --feats_type raw \
- --token_type char \
- --use_lm ${use_lm} \
- --use_word_lm ${use_wordlm} \
- --lm_config "${lm_config}" \
- --asr_config "${asr_config}" \
- --sa_asr_config "${sa_asr_config}" \
- --inference_config "${inference_config}" \
- --train_set "${train_set}" \
- --valid_set "${valid_set}" \
- --test_sets "${test_sets}" \
- --lm_train_text "data/${train_set}/text" "$@"
diff --git a/egs/alimeeting/sa_asr/README.md b/egs/alimeeting/sa_asr/README.md
new file mode 100644
index 0000000..1ae023a
--- /dev/null
+++ b/egs/alimeeting/sa_asr/README.md
@@ -0,0 +1,86 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+# Train
+First you need to install the FunASR and ModelScope. ([installation](https://github.com/alibaba-damo-academy/FunASR#installation))
+After the FunASR and ModelScope is installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `.dataset` should organized as follow:
+```shell
+dataset
+|鈥斺�� Eval_Ali_far
+|鈥斺�� Eval_Ali_near
+|鈥斺�� Test_Ali_far
+|鈥斺�� Test_Ali_near
+|鈥斺�� Train_Ali_far
+|鈥斺�� Train_Ali_near
+```
+Then you can run this receipe by running:
+```shell
+bash run.sh --stage 0 --stop-stage 6
+```
+There are 8 stages in `run.sh`:
+```shell
+stage 0: Data preparation and remove the audio which is too long or too short.
+stage 1: Speaker profile and CMVN Generation.
+stage 2: Dictionary preparation.
+stage 3: LM training (not supported).
+stage 4: ASR Training.
+stage 5: SA-ASR Training.
+stage 6: Inference
+stage 7: Inference with Test_2023_Ali_far
+```
+<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
+# Infer
+1. Download the final test set and extracted
+2. Put the audios in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
+3. Set the `test_2023` in `run.sh` should be to `Test_2023_Ali_far`.
+4. Run the `run.sh` as follow.
+```shell
+# Prepare test_2023 set
+bash run.sh --stage 0 --stop-stage 1
+# Decode test_2023 set
+bash run.sh --stage 7 --stop-stage 7
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Test Set are also provided to show the impact of speaker profile accuracy.
+<!-- <table>
+ <tr >
+ <td rowspan="2"></td>
+ <td colspan="2">SI-CER(%)</td>
+ <td colspan="2">cpCER(%)</td>
+ </tr>
+ <tr>
+ <td>Eval</td>
+ <td>Test</td>
+ <td>Eval</td>
+ <td>Test</td>
+ </tr>
+ <tr>
+ <td>oracle profile</td>
+ <td>32.05</td>
+ <td>32.72</td>
+ <td>47.40</td>
+ <td>42.92</td>
+ </tr>
+ <tr>
+ <td>cluster profile</td>
+ <td>32.05</td>
+ <td>32.73</td>
+ <td>53.76</td>
+ <td>49.37</td>
+ </tr>
+</table> -->
+| |SI-CER(%) |cp-CER(%) |
+|:---------------|:------------:|----------:|
+|oracle profile |32.72 |42.92 |
+|cluster profile|32.73 |49.37 |
+
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413鈥�4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
rename to egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
similarity index 81%
rename from egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
rename to egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
index a8c9968..507ad30 100644
--- a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
+++ b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
@@ -1,9 +1,14 @@
# network architecture
-frontend: default
+frontend: multichannelfrontend
frontend_conf:
+ fs: 16000
+ window: hann
n_fft: 400
- win_length: 400
- hop_length: 160
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
use_channel: 0
# encoder related
@@ -47,9 +52,18 @@
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# minibatch related
-batch_type: numel
-batch_bins: 10000000 # reduce/increase this number according to your GPU memory
+
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
# optimization related
accum_grad: 1
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
similarity index 81%
rename from egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
rename to egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
index 612fd23..18614dd 100644
--- a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
+++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
@@ -1,10 +1,16 @@
# network architecture
-frontend: default
+frontend: multichannelfrontend
frontend_conf:
+ fs: 16000
+ window: hann
n_fft: 400
- win_length: 400
- hop_length: 160
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
use_channel: 0
+ mc: False
# encoder related
asr_encoder: conformer
@@ -44,6 +50,7 @@
pooling_type: statistic
num_nodes_resnet1: 256
num_nodes_last_layer: 256
+ batchnorm_momentum: 0.5
# decoder related
decoder: sa_decoder
@@ -63,13 +70,23 @@
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
+ max_spk_num: 4
ctc_conf:
ignore_nan_grad: true
# minibatch related
-batch_type: numel
-batch_bins: 10000000
+dataset_conf:
+ data_names: speech,text,profile,text_id
+ data_types: sound,text,npy,text_int
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
# optimization related
accum_grad: 1
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
similarity index 74%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
index c13ee42..fd76837 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
+++ b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
@@ -21,6 +21,8 @@
SECONDS=0
tgt=Train #Train or Eval
+min_wav_duration=0.1
+max_wav_duration=20
log "$0 $*"
@@ -57,27 +59,24 @@
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
+mkdir -p data/org
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting near dir"
find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist
- awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
- find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
+ awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
+ find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
n1_wav=$(wc -l < $near_dir/wavlist)
n2_text=$(wc -l < $near_dir/textgrid.flist)
log near file found $n1_wav wav and $n2_text text.
- paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
-
- # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
- cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+ paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
- #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
- #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@@ -97,9 +96,7 @@
n2_text=$(wc -l < $far_dir/textgrid.flist)
log far file found $n1_wav wav and $n2_text text.
- paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
-
- cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+ paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
@@ -119,28 +116,28 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- log "stage 3: finali data process"
+ log "stage 3: final data process"
local/fix_data_dir.sh $near_dir
local/fix_data_dir.sh $far_dir
- local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
- local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+ local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
+ local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
- sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
- sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+ sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+ log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@@ -150,14 +147,15 @@
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
- local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+ local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
+ rm -rf data/local
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
diff --git a/egs/alimeeting/sa-asr/local/apply_map.pl b/egs/alimeeting/sa_asr/local/apply_map.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/apply_map.pl
rename to egs/alimeeting/sa_asr/local/apply_map.pl
diff --git a/egs/alimeeting/sa-asr/local/combine_data.sh b/egs/alimeeting/sa_asr/local/combine_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/combine_data.sh
rename to egs/alimeeting/sa_asr/local/combine_data.sh
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.py b/egs/alimeeting/sa_asr/local/compute_cmvn.py
new file mode 100755
index 0000000..d16563a
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.py
@@ -0,0 +1,134 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+from funasr.models.frontend.default import DefaultFrontend
+import torch
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="computer global cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dimension",
+ )
+ parser.add_argument(
+ "--wav_path",
+ default=False,
+ required=True,
+ type=str,
+ help="the path of wav scps",
+ )
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ help="the config file for computing cmvn",
+ )
+ parser.add_argument(
+ "--idx",
+ default=1,
+ required=True,
+ type=int,
+ help="index",
+ )
+ return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
+ total_frames = 0
+
+ # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+ # for key, mat in ark_reader:
+ # mean_stats += np.sum(mat, axis=0)
+ # var_stats += np.sum(np.square(mat), axis=0)
+ # total_frames += mat.shape[0]
+
+ with open(args.config_file) as f:
+ configs = yaml.safe_load(f)
+ frontend_configs = configs.get("frontend_conf", {})
+ num_mel_bins = frontend_configs.get("n_mels", 80)
+ frame_length = frontend_configs.get("frame_length", 25)
+ frame_shift = frontend_configs.get("frame_shift", 10)
+ window_type = frontend_configs.get("window", "hamming")
+ resample_rate = frontend_configs.get("fs", 16000)
+ n_fft = frontend_configs.get("n_fft", "400")
+ use_channel = frontend_configs.get("use_channel", None)
+ assert num_mel_bins == args.dim
+ frontend = DefaultFrontend(
+ fs=resample_rate,
+ n_fft=n_fft,
+ win_length=frame_length * 16,
+ hop_length=frame_shift * 16,
+ window=window_type,
+ n_mels=num_mel_bins,
+ use_channel=use_channel,
+ )
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ wavform, _ = torchaudio.load(wav_file)
+ fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
+ fbank = fbank.squeeze(0).numpy()
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
+
+ cmvn_info = {
+ 'mean_stats': list(mean_stats.tolist()),
+ 'var_stats': list(var_stats.tolist()),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.sh b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
new file mode 100755
index 0000000..00d08d1
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+fbankdir=
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+config_file=
+scale=1.0
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+# shellcheck disable=SC2046
+head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+ python local/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --config_file $config_file \
+ --idx JOB \
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
+
+echo "$0: Succeeded compute global cmvn"
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa_asr/local/compute_cpcer.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/compute_cpcer.py
rename to egs/alimeeting/sa_asr/local/compute_cpcer.py
diff --git a/egs/alimeeting/sa_asr/local/convert_model.py b/egs/alimeeting/sa_asr/local/convert_model.py
new file mode 100644
index 0000000..f0f7997
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/convert_model.py
@@ -0,0 +1,29 @@
+import codecs
+import pdb
+import sys
+import torch
+
+char1 = sys.argv[1]
+char2 = sys.argv[2]
+model1 = torch.load(sys.argv[3], map_location='cpu')
+model2_path = sys.argv[4]
+
+d_new = model1
+char1_list = []
+map_list = []
+
+
+with codecs.open(char1) as f:
+ for line in f.readlines():
+ char1_list.append(line.strip())
+
+with codecs.open(char2) as f:
+ for line in f.readlines():
+ map_list.append(char1_list.index(line.strip()))
+print(map_list)
+
+for k, v in d_new.items():
+ if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
+ d_new[k] = v[map_list]
+
+torch.save(d_new, model2_path)
diff --git a/egs/alimeeting/sa-asr/local/copy_data_dir.sh b/egs/alimeeting/sa_asr/local/copy_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/copy_data_dir.sh
rename to egs/alimeeting/sa_asr/local/copy_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh b/egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh b/egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
rename to egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh b/egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/split_data.sh b/egs/alimeeting/sa_asr/local/data/split_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/split_data.sh
rename to egs/alimeeting/sa_asr/local/data/split_data.sh
diff --git a/egs/alimeeting/sa_asr/local/download_and_untar.sh b/egs/alimeeting/sa_asr/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py b/egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py
rename to egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa_asr/local/download_xvector_model.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_xvector_model.py
rename to egs/alimeeting/sa_asr/local/download_xvector_model.py
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
rename to egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
diff --git a/egs/alimeeting/sa-asr/local/fix_data_dir.sh b/egs/alimeeting/sa_asr/local/fix_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/fix_data_dir.sh
rename to egs/alimeeting/sa_asr/local/fix_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa_asr/local/format_wav_scp.py
similarity index 98%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.py
rename to egs/alimeeting/sa_asr/local/format_wav_scp.py
index 1fd63d6..cb0eac3 100755
--- a/egs/alimeeting/sa-asr/local/format_wav_scp.py
+++ b/egs/alimeeting/sa_asr/local/format_wav_scp.py
@@ -11,7 +11,6 @@
import resampy
import soundfile
from tqdm import tqdm
-from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
@@ -31,7 +30,6 @@
(3, 4, 5)
"""
- assert check_argument_types()
if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
return None
return tuple(map(int, integers.strip().split(",")))
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa_asr/local/format_wav_scp.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.sh
rename to egs/alimeeting/sa_asr/local/format_wav_scp.sh
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
similarity index 97%
rename from egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
rename to egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
index c37abf9..859b72f 100644
--- a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
+++ b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
@@ -63,7 +63,7 @@
wav_scp_file = open(path+'/wav.scp', 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -92,8 +92,8 @@
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
meeting_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
index 18286b4..2a99b2b 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
@@ -9,7 +9,7 @@
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -22,8 +22,8 @@
raw_wav_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
raw_wav_map[meeting] = wav_path
spk_map = {}
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
similarity index 96%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
index 186f1de..ff65a1f 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
@@ -5,7 +5,7 @@
if __name__=="__main__":
- path = sys.argv[1] # dump2/raw/Train_Ali_far
+ path = sys.argv[1]
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
@@ -29,7 +29,7 @@
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
- spk = meeting+'_' + spk_id
+ spk = meeting + '_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
rename to egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
rename to egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
index d900bb1..488344f 100755
--- a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
+++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
@@ -30,8 +30,7 @@
meetingid_map = {}
for line in spk2utt:
spkid = line.strip().split(" ")[0]
- meeting_id_list = spkid.split("_")[:3]
- meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+ meeting_id = spkid.split("-")[0]
if meeting_id not in meetingid_map:
meetingid_map[meeting_id] = 1
else:
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa_asr/local/process_text_id.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_id.py
rename to egs/alimeeting/sa_asr/local/process_text_id.py
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa_asr/local/process_text_spk_merge.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_spk_merge.py
rename to egs/alimeeting/sa_asr/local/process_text_spk_merge.py
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
rename to egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
diff --git a/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
rename to egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa_asr/local/text_format.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_format.pl
rename to egs/alimeeting/sa_asr/local/text_format.pl
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa_asr/local/text_normalize.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_normalize.pl
rename to egs/alimeeting/sa_asr/local/text_normalize.pl
diff --git a/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
rename to egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
diff --git a/egs/alimeeting/sa-asr/local/validate_data_dir.sh b/egs/alimeeting/sa_asr/local/validate_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_data_dir.sh
rename to egs/alimeeting/sa_asr/local/validate_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/validate_text.pl b/egs/alimeeting/sa_asr/local/validate_text.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_text.pl
rename to egs/alimeeting/sa_asr/local/validate_text.pl
diff --git a/egs/alimeeting/sa_asr/path.sh b/egs/alimeeting/sa_asr/path.sh
new file mode 100755
index 0000000..83ae507
--- /dev/null
+++ b/egs/alimeeting/sa_asr/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
+export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
diff --git a/egs/alimeeting/sa_asr/run.sh b/egs/alimeeting/sa_asr/run.sh
new file mode 100755
index 0000000..43d0da1
--- /dev/null
+++ b/egs/alimeeting/sa_asr/run.sh
@@ -0,0 +1,435 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=8
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="data" #feature output dictionary
+exp_dir="exp"
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="1.0"
+min_wav_duration=0.1
+max_wav_duration=20
+profile_modes="cluster oracle"
+stage=7
+stop_stage=7
+
+# feature configuration
+feats_dim=80
+nj=32
+
+# data
+raw_data=
+data_url=
+
+# exp tag
+tag=""
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far Eval_Ali_far"
+test_2023="Test_2023_Ali_far_release"
+
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+inference_config=conf/decode_asr_rnn.yaml
+inference_sa_asr_model=valid.acc_spk.ave.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ # remove long/short data
+ for x in ${train_set} ${valid_set} ${test_sets}; do
+ cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
+ rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
+ <"${feats_dir}/${x}/utt2num_samples" \
+ awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
+ >"${feats_dir}/${x}/utt2num_samples_rmls"
+ mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
+ <"${feats_dir}/${x}/wav.scp" \
+ utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \
+ >"${feats_dir}/${x}/wav.scp_rmls"
+ mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
+ <"${feats_dir}/${x}/text" \
+ awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
+ mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
+ local/fix_${feats_dir}_dir.sh "${feats_dir}/${x}"
+ <"${feats_dir}/${x}/utt2spk_all_fifo" \
+ utils/filter_scp.pl "${feats_dir}/${x}/text" \
+ >"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
+ mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
+ done
+ for x in ${test_2023}; do
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
+ paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
+ cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Speaker profile and CMVN Generation"
+
+ mkdir -p "profile_log"
+ for x in "${train_set}" "${valid_set}" "${test_sets}"; do
+ # generate text_id spk2id
+ python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
+ # generate text_id_train for sot
+ python local/process_text_id.py ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id_train"
+ # generate oracle_embedding from single-speaker audio segment
+ echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
+ python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
+ echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
+ # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+ if [ "${x}" = "${train_set}" ]; then
+ python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
+ else
+ python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
+ fi
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ if [ "${x}" = "${valid_set}" ] || [ "${x}" = "${test_sets}" ]; then
+ echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log"
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ fi
+ # compute CMVN
+ if [ "${x}" = "${train_set}" ]; then
+ local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+ fi
+ done
+
+ for x in "${test_2023}"; do
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ done
+fi
+
+token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: ASR Training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ mkdir -p ${asr_exp}
+ mkdir -p ${asr_exp}/log
+ INIT_FILE=${asr_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --token_type char \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/${asr_model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+fi
+
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "SA-ASR training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ mkdir -p ${sa_asr_exp}
+ mkdir -p ${sa_asr_exp}/log
+ INIT_FILE=${sa_asr_exp}/ddp_init
+ if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
+ ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
+ ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
+ fi
+
+ if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+ # download xvector extractor model file
+ python local/download_xvector_model.py ${exp_dir}
+ echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+ fi
+
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model sa_asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --unused_parameters true \
+ --token_type char \
+ --resume true \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text,profile.scp,text_id_train" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
+ --output_dir ${exp_dir}/${sa_asr_model_dir} \
+ --config $sa_asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "stage 6: Inference test sets"
+ for x in ${test_sets}; do
+ for profile_mode in ${profile_modes}; do
+ echo "decoding ${x} with ${profile_mode} profile"
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if [ $profile_mode = "oracle" ]; then
+ profile_scp=${profile_mode}_profile_nopadding.scp
+ else
+ profile_scp=${profile_mode}_profile_infer.scp
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
+ sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
+ python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+ done
+ done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "stage 7: Inference test 2023"
+ for x in ${test_2023}; do
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+
+ python local/process_text_spk_merge.py ${_dir}
+
+ done
+fi
+
+
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa_asr/utils
similarity index 100%
rename from egs/alimeeting/sa-asr/utils
rename to egs/alimeeting/sa_asr/utils
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/finetune.py
index 0393212..79fd34d 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/finetune.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/finetune.py
@@ -1,5 +1,4 @@
import os
-<<<<<<< HEAD
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
@@ -21,50 +20,17 @@
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
-=======
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
->>>>>>> main
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
-<<<<<<< HEAD
params = modelscope_args(model="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch", data_path="./data")
params.output_dir = "./checkpoint" # m妯″瀷淇濆瓨璺緞
params.data_path = "./example_data/" # 鏁版嵁璺緞
params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
params.batch_bins = 2000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
- params.max_epoch = 50 # 鏈�澶ц缁冭疆鏁�
+ params.max_epoch = 20 # 鏈�澶ц缁冭疆鏁�
params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
-=======
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch"
- params["model_revision"] = None
->>>>>>> main
- modelscope_finetune(params)
+ modelscope_finetune(params)
\ No newline at end of file
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/infer.py
index a0f0965..da8859e 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch/infer.py
@@ -1,33 +1,3 @@
-<<<<<<< HEAD
-import os
-import shutil
-import argparse
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-def modelscope_infer(args):
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=args.model,
- output_dir=args.output_dir,
- batch_size=args.batch_size,
- param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
- )
- inference_pipeline(audio_in=args.audio_in)
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--model', type=str, default="damo/speech_UniASR_asr_2pass-tr-16k-common-vocab1582-pytorch")
- parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
- parser.add_argument('--output_dir', type=str, default="./results/")
- parser.add_argument('--decoding_mode', type=str, default="normal")
- parser.add_argument('--hotword_txt', type=str, default=None)
- parser.add_argument('--batch_size', type=int, default=64)
- parser.add_argument('--gpuid', type=str, default="0")
- args = parser.parse_args()
- modelscope_infer(args)
-=======
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -40,5 +10,4 @@
output_dir=output_dir,
)
rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
->>>>>>> main
+ print(rec_result)
\ No newline at end of file
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/README.md
index eff933e..9a84f9b 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/README.md
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/README.md
@@ -41,8 +41,7 @@
- Modify inference related parameters in `infer_after_finetune.py`
- <strong>output_dir:</strong> # result dir
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
- - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave
- .pb`
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
- Then you can run the pipeline to finetune with:
```python
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/demo.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/demo.py
new file mode 100644
index 0000000..7ca7118
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/demo.py
@@ -0,0 +1,12 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+decoding_mode="normal" #fast, normal, offline
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online',
+ param_dict={"decoding_model": decoding_mode}
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
deleted file mode 100644
index 876d51c..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os
-import shutil
-from multiprocessing import Pool
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_core(output_dir, split_dir, njob, idx):
- output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
- gpu_id = (int(idx) - 1) // njob
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online",
- output_dir=output_dir_job,
- batch_size=1
- )
- audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
-
-
-def modelscope_infer(params):
- # prepare for multi-GPU decoding
- ngpu = params["ngpu"]
- njob = params["njob"]
- output_dir = params["output_dir"]
- if os.path.exists(output_dir):
- shutil.rmtree(output_dir)
- os.mkdir(output_dir)
- split_dir = os.path.join(output_dir, "split")
- os.mkdir(split_dir)
- nj = ngpu * njob
- wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
- with open(wav_scp_file) as f:
- lines = f.readlines()
- num_lines = len(lines)
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(lines[start:])
- else:
- f.writelines(lines[start:end])
- start = end
-
- p = Pool(nj)
- for i in range(nj):
- p.apply_async(modelscope_infer_core,
- args=(output_dir, split_dir, njob, str(i + 1)))
- p.close()
- p.join()
-
- # combine decoding results
- best_recog_path = os.path.join(output_dir, "1best_recog")
- os.mkdir(best_recog_path)
- files = ["text", "token", "score"]
- for file in files:
- with open(os.path.join(best_recog_path, file), "w") as f:
- for i in range(nj):
- job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
- with open(job_file) as f_job:
- lines = f_job.readlines()
- f.writelines(lines)
-
- # If text exists, compute CER
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "text")
- compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
-
-
-if __name__ == "__main__":
- params = {}
- params["data_dir"] = "./data/test"
- params["output_dir"] = "./results"
- params["ngpu"] = 1
- params["njob"] = 1
- modelscope_infer(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.sh b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.sh
new file mode 100644
index 0000000..2d7a2da
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=2
+model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online"
+data_dir="./data/test"
+output_dir="./results"
+batch_size=1
+gpu_inference=false # whether to perform gpu decoding
+gpuid_list="-1" # set gpus, e.g., gpuid_list="0,1"
+njob=32 # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
+checkpoint_dir=
+checkpoint_name="valid.cer_ctc.ave.pb"
+decoding_mode="normal"
+
+. utils/parse_options.sh || exit 1;
+
+if ${gpu_inference} == "true"; then
+ nj=$(echo $gpuid_list | awk -F "," '{print NF}')
+else
+ nj=$njob
+ batch_size=1
+ gpuid_list=""
+ for JOB in $(seq ${nj}); do
+ gpuid_list=$gpuid_list"-1,"
+ done
+fi
+
+mkdir -p $output_dir/split
+split_scps=""
+for JOB in $(seq ${nj}); do
+ split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
+done
+perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
+
+if [ -n "${checkpoint_dir}" ]; then
+ python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
+ model=${checkpoint_dir}/${model}
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
+ echo "Decoding ..."
+ gpuid_list_array=(${gpuid_list//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+ mkdir -p ${output_dir}/output.$JOB
+ python infer.py \
+ --model ${model} \
+ --audio_in ${output_dir}/split/wav.$JOB.scp \
+ --output_dir ${output_dir}/output.$JOB \
+ --batch_size ${batch_size} \
+ --gpuid ${gpuid} \
+ --decoding_mode ${decoding_mode}
+ }&
+ done
+ wait
+
+ mkdir -p ${output_dir}/1best_recog
+ for f in token score text; do
+ if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${nj}"); do
+ cat "${output_dir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${output_dir}/1best_recog/${f}"
+ fi
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
+ echo "Computing WER ..."
+ cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
+ cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
+ python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
+ tail -n 3 ${output_dir}/1best_recog/text.cer
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
+ echo "SpeechIO TIOBE textnorm"
+ echo "$0 --> Normalizing REF text ..."
+ ./utils/textnorm_zh.py \
+ --has_key --to_upper \
+ ${data_dir}/text \
+ ${output_dir}/1best_recog/ref.txt
+
+ echo "$0 --> Normalizing HYP text ..."
+ ./utils/textnorm_zh.py \
+ --has_key --to_upper \
+ ${output_dir}/1best_recog/text.proc \
+ ${output_dir}/1best_recog/rec.txt
+ grep -v $'\t$' ${output_dir}/1best_recog/rec.txt > ${output_dir}/1best_recog/rec_non_empty.txt
+
+ echo "$0 --> computing WER/CER and alignment ..."
+ ./utils/error_rate_zh \
+ --tokenizer char \
+ --ref ${output_dir}/1best_recog/ref.txt \
+ --hyp ${output_dir}/1best_recog/rec_non_empty.txt \
+ ${output_dir}/1best_recog/DETAILS.txt | tee ${output_dir}/1best_recog/RESULTS.txt
+ rm -rf ${output_dir}/1best_recog/rec.txt ${output_dir}/1best_recog/rec_non_empty.txt
+fi
+
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
deleted file mode 100644
index fd124ff..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import json
-import os
-import shutil
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_after_finetune(params):
- # prepare for decoding
- pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
- for file_name in params["required_files"]:
- if file_name == "configuration.json":
- with open(os.path.join(pretrained_model_path, file_name)) as f:
- config_dict = json.load(f)
- config_dict["model"]["am_model_name"] = params["decoding_model_name"]
- with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
- json.dump(config_dict, f, indent=4, separators=(',', ': '))
- else:
- shutil.copy(os.path.join(pretrained_model_path, file_name),
- os.path.join(params["output_dir"], file_name))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
-
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=params["output_dir"],
- output_dir=decoding_path,
- batch_size=1
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
-
- # computer CER if GT text is set
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/text")
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
-
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online"
- params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data/test"
- params["decoding_model_name"] = "20epoch.pb"
- modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/utils b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/utils
new file mode 120000
index 0000000..2ac163f
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/utils
@@ -0,0 +1 @@
+../../../../egs/aishell/transformer/utils
\ No newline at end of file
diff --git a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
index 5bc205c..f54399a 100644
--- a/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
+++ b/egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.py
@@ -10,10 +10,9 @@
task=Tasks.auto_speech_recognition,
model=args.model,
output_dir=args.output_dir,
- batch_size=args.batch_size,
param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
)
- inference_pipeline(audio_in=args.audio_in)
+ inference_pipeline(audio_in=args.audio_in, batch_size_token=args.batch_size_token)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -22,7 +21,7 @@
parser.add_argument('--output_dir', type=str, default="./results/")
parser.add_argument('--decoding_mode', type=str, default="normal")
parser.add_argument('--hotword_txt', type=str, default=None)
- parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--batch_size_token', type=int, default=5000)
parser.add_argument('--gpuid', type=str, default="0")
args = parser.parse_args()
modelscope_infer(args)
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
index dc867b0..aa0db93 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
@@ -17,7 +17,7 @@
diar_model_config="sond.yaml",
model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
- sv_model_revision="master",
+ sv_model_revision="v1.2.2",
)
# use audio_list as the input, where the first one is the record to be detected
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
index 3116f6d..581f7aa 100644
--- a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
@@ -4,8 +4,7 @@
inference_pipeline = pipeline(
task=Tasks.speech_timestamp,
model='damo/speech_timestamp_prediction-v1-16k-offline',
- model_revision='v1.1.0',
- output_dir=None)
+ model_revision='v1.1.0')
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 0e203c4..259a286 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1,66 +1,44 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
-import logging
-import sys
-import time
+
+import codecs
import copy
+import logging
import os
import re
-import codecs
import tempfile
-import requests
from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
import numpy as np
+import requests
import torch
from packaging.version import parse as V
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.modules.beam_search.beam_search import BeamSearch
-# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.bin.punc_infer import Text2Punc
-from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.tasks.asr import frontend_choices
+
class Speech2Text:
"""Speech2Text class
@@ -73,36 +51,35 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
):
- assert check_argument_types()
-
+
# 1. Build ASR model
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
@@ -110,16 +87,15 @@
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
- from funasr.tasks.asr import frontend_choices
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -127,24 +103,24 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
+ lm, lm_train_args = build_model_from_file(
lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
-
+
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
from funasr.modules.beam_search.beam_search import BeamSearch
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -162,13 +138,13 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
+
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -180,7 +156,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -193,10 +169,10 @@
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -213,12 +189,11 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
@@ -229,48 +204,48 @@
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
+
return results
+
class Speech2TextParaformer:
"""Speech2Text class
@@ -308,13 +283,11 @@
decoding_ind: int = 0,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -336,8 +309,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -398,6 +371,7 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+ self.cmvn_file = cmvn_file
# 6. [Optional] Build hotword list from str, local file or url
self.hotword_list = None
@@ -433,7 +407,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -466,18 +439,21 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
+ if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
+ NeatContextualParaformer):
if self.hotword_list:
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
+ pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
+ pre_token_length, hw_list=self.hotword_list)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
_, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
- pre_token_length) # test no bias cif2
+ pre_token_length) # test no bias cif2
results = []
b, n, d = decoder_out.size()
@@ -493,9 +469,9 @@
else:
if pre_token_length[i] == 0:
yseq = torch.tensor(
- [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device
+ [self.asr_model.sos] + [self.asr_model.eos], device=pre_acoustic_embeds.device
)
- score = torch.tensor(0.0, device=yseq.device)
+ score = torch.tensor(0.0, device=pre_acoustic_embeds.device)
else:
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
@@ -527,17 +503,53 @@
text = None
timestamp = []
if isinstance(self.asr_model, BiCifParaformer):
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
- us_peaks[i][:enc_len[i]*3],
- copy.copy(token),
- vad_offset=begin_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i] * 3],
+ us_peaks[i][:enc_len[i] * 3],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
-
- # assert check_return_type(results)
return results
def generate_hotwords_list(self, hotword_list_or_file):
+ def load_seg_dict(seg_dict_file):
+ seg_dict = {}
+ assert isinstance(seg_dict_file, str)
+ with open(seg_dict_file, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+ def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+ seg_dict = None
+ if self.cmvn_file is not None:
+ model_dir = os.path.dirname(self.cmvn_file)
+ seg_dict_file = os.path.join(model_dir, 'seg_dict')
+ if os.path.exists(seg_dict_file):
+ seg_dict = load_seg_dict(seg_dict_file)
+ else:
+ seg_dict = None
# for None
if hotword_list_or_file is None:
hotword_list = None
@@ -549,8 +561,11 @@
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -570,8 +585,11 @@
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -583,13 +601,17 @@
hotword_str_list = []
for hw in hotword_list_or_file.strip().split():
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hw_list = hw.strip().split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Hotword list: {}.".format(hotword_str_list))
else:
hotword_list = None
return hotword_list
+
class Speech2TextParaformerOnline:
"""Speech2Text class
@@ -626,13 +648,11 @@
hotword_list_or_file: str = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -654,8 +674,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -747,7 +767,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -789,7 +808,7 @@
enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
- pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1]
+ pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
if torch.max(pre_token_length) < 1:
return []
decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
@@ -839,11 +858,11 @@
postprocessed_result += item + " "
else:
postprocessed_result += item
-
+
results.append(postprocessed_result)
- # assert check_return_type(results)
return results
+
class Speech2TextUniASR:
"""Speech2Text class
@@ -882,13 +901,11 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -914,8 +931,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, device, "lm"
)
scorers["lm"] = lm.lm
@@ -1007,7 +1024,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -1075,9 +1091,8 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
-
+
class Speech2TextMFCCA:
"""Speech2Text class
@@ -1090,45 +1105,43 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ **kwargs,
):
- assert check_argument_types()
-
+
# 1. Build ASR model
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -1136,11 +1149,11 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm.to(device)
scorers["lm"] = lm.lm
@@ -1148,11 +1161,11 @@
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -1176,7 +1189,7 @@
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1188,7 +1201,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -1200,10 +1213,10 @@
self.device = device
self.dtype = dtype
self.nbest = nbest
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
@@ -1220,7 +1233,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1231,46 +1243,45 @@
# lenghts: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
-
+
assert len(enc) == 1, len(enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
+
return results
@@ -1298,45 +1309,43 @@
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- beam_search_config: Dict[str, Any] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- beam_size: int = 5,
- dtype: str = "float32",
- lm_weight: float = 1.0,
- quantize_asr_model: bool = False,
- quantize_modules: List[str] = None,
- quantize_dtype: str = "qint8",
- nbest: int = 1,
- streaming: bool = False,
- simu_streaming: bool = False,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- display_partial_hypotheses: bool = False,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
) -> None:
"""Construct a Speech2Text object."""
super().__init__()
-
- assert check_argument_types()
- from funasr.tasks.asr import ASRTransducerTask
- asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
-
+
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
+
if quantize_asr_model:
if quantize_modules is not None:
if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
@@ -1344,36 +1353,36 @@
"Only 'Linear' and 'LSTM' modules are currently supported"
" by PyTorch and in --quantize_modules"
)
-
+
q_config = set([getattr(torch.nn, q) for q in quantize_modules])
else:
q_config = {torch.nn.Linear}
-
+
if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch"
" version < 1.5.0. Switching to qint8 dtype instead."
)
q_dtype = getattr(torch, quantize_dtype)
-
+
asr_model = torch.quantization.quantize_dynamic(
asr_model, q_config, dtype=q_dtype
).eval()
else:
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm_scorer = lm.lm
else:
lm_scorer = None
-
+
# 4. Build BeamSearch object
if beam_search_config is None:
beam_search_config = {}
-
+
beam_search = BeamSearchTransducer(
asr_model.decoder,
asr_model.joint_network,
@@ -1383,14 +1392,14 @@
nbest=nbest,
**beam_search_config,
)
-
+
token_list = asr_model.token_list
-
+
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1402,60 +1411,60 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.device = device
self.dtype = dtype
self.nbest = nbest
-
+
self.converter = converter
self.tokenizer = tokenizer
-
+
self.beam_search = beam_search
self.streaming = streaming
self.simu_streaming = simu_streaming
self.chunk_size = max(chunk_size, 0)
self.left_context = left_context
self.right_context = max(right_context, 0)
-
+
if not streaming or chunk_size == 0:
self.streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
-
+
if not simu_streaming or chunk_size == 0:
self.simu_streaming = False
self.asr_model.encoder.dynamic_chunk_training = False
-
+
self.frontend = frontend
self.window_size = self.chunk_size + self.right_context
-
+
if self.streaming:
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
-
+
self.last_chunk_length = (
- self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
)
self.reset_inference_cache()
-
+
def reset_inference_cache(self) -> None:
"""Reset Speech2Text parameters."""
self.frontend_cache = None
-
+
self.asr_model.encoder.reset_streaming_cache(
self.left_context, device=self.device
)
self.beam_search.reset_inference_cache()
-
+
self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
+
@torch.no_grad()
def streaming_decode(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- is_final: bool = True,
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
) -> List[HypothesisTransducer]:
"""Speech2Text streaming call.
Args:
@@ -1473,13 +1482,13 @@
)
speech = torch.cat([speech, pad],
dim=0) # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
-
+
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.chunk_forward(
@@ -1491,14 +1500,14 @@
right_context=self.right_context,
)
nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
-
+
self.num_processed_frames += self.chunk_size
-
+
if is_final:
self.reset_inference_cache()
-
+
return nbest_hyps
-
+
@torch.no_grad()
def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
"""Speech2Text call.
@@ -1507,30 +1516,29 @@
Returns:
nbest_hypothesis: N-best hypothesis.
"""
- assert check_argument_types()
-
+
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if self.frontend is not None:
speech = torch.unsqueeze(speech, axis=0)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
+ else:
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
self.right_context)
nbest_hyps = self.beam_search(enc_out[0])
-
+
return nbest_hyps
-
+
@torch.no_grad()
def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
"""Speech2Text call.
@@ -1539,8 +1547,7 @@
Returns:
nbest_hypothesis: N-best hypothesis.
"""
- assert check_argument_types()
-
+
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1548,19 +1555,19 @@
speech = torch.unsqueeze(speech, axis=0)
speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
+ else:
feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
+
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
-
+
enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
+
nbest_hyps = self.beam_search(enc_out[0])
-
+
return nbest_hyps
-
+
def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
"""Build partial or final results from the hypotheses.
Args:
@@ -1569,47 +1576,20 @@
results: Results containing different representation for the hypothesis.
"""
results = []
-
+
for hyp in nbest_hyps:
token_int = list(filter(lambda x: x != 0, hyp.yseq))
-
+
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
-
+
+
return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ) -> Speech2Text:
- """Build Speech2Text instance from the pretrained model.
- Args:
- model_tag: Model tag of the pretrained models.
- Return:
- : Speech2Text instance.
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
@@ -1623,53 +1603,53 @@
[(text, token, token_int, hypothesis object), ...]
"""
-
+
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
):
- assert check_argument_types()
-
+
# 1. Build ASR model
- from funasr.tasks.sa_asr import ASRTask
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend == 'wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ from funasr.tasks.sa_asr import frontend_choices
+ if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
+
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
-
+
decoder = asr_model.decoder
-
+
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
@@ -1677,24 +1657,24 @@
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
-
+
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
-
+
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
-
+
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-
+
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
@@ -1712,13 +1692,13 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
+
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
-
+
if token_type is None:
tokenizer = None
elif token_type == "bpe":
@@ -1730,7 +1710,7 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
+
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
@@ -1743,11 +1723,11 @@
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
-
+
@torch.no_grad()
def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
- profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
+ profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
) -> List[
Tuple[
Optional[str],
@@ -1765,15 +1745,14 @@
text, text_id, token, token_int, hyp
"""
- assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
+
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
-
+
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
@@ -1784,10 +1763,10 @@
feats_len = speech_lengths
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
-
+
# a. To device
batch = to_device(batch, device=self.device)
-
+
# b. Forward Encoder
asr_enc, _, spk_enc = self.asr_model.encode(**batch)
if isinstance(asr_enc, tuple):
@@ -1796,30 +1775,30 @@
spk_enc = spk_enc[0]
assert len(asr_enc) == 1, len(asr_enc)
assert len(spk_enc) == 1, len(spk_enc)
-
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
-
+
nbest_hyps = nbest_hyps[: self.nbest]
-
+
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (HypothesisSAASR)), type(hyp)
-
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1: last_pos]
else:
token_int = hyp.yseq[1: last_pos].tolist()
-
+
spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
-
+
token_ori = self.converter.ids2tokens(token_int)
text_ori = self.tokenizer.tokens2text(token_ori)
-
+
text_ori_spklist = text_ori.split('$')
cur_index = 0
spk_choose = []
@@ -1831,32 +1810,31 @@
spk_weights_local = spk_weights_local.mean(dim=0)
spk_choose_local = spk_weights_local.argmax(-1)
spk_choose.append(spk_choose_local.item() + 1)
-
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
-
+
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
-
+
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
-
+
text_spklist = text.split('$')
assert len(spk_choose) == len(text_spklist)
-
+
spk_list = []
for i in range(len(text_spklist)):
text_split = text_spklist[i]
n = len(text_split)
spk_list.append(str(spk_choose[i]) * n)
-
+
text_id = '$'.join(spk_list)
-
+
assert len(text) == len(text_id)
-
+
results.append((text, text_id, token, token_int, hyp))
-
- assert check_return_type(results)
+
return results
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f84212d..37a5fe4 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,111 +7,78 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
from pathlib import Path
+from typing import Dict
+from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-import yaml
+
import numpy as np
import torch
import torchaudio
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearch
-# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+import soundfile
+import yaml
+from funasr.bin.asr_infer import Speech2Text
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
+from funasr.bin.asr_infer import Speech2TextSAASR
+from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
+from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import asr_utils, postprocess_utils
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-from funasr.bin.asr_infer import Speech2Text
-from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
-from funasr.bin.asr_infer import Speech2TextUniASR
-from funasr.bin.asr_infer import Speech2TextMFCCA
-from funasr.bin.vad_infer import Speech2VadSegment
-from funasr.bin.punc_infer import Text2Punc
-from funasr.bin.tp_infer import Speech2Timestamp
-from funasr.bin.asr_infer import Speech2TextTransducer
-from funasr.bin.asr_infer import Speech2TextSAASR
+
def inference_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -120,23 +87,23 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -160,7 +127,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -173,20 +140,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -197,14 +162,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -212,19 +177,19 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -233,67 +198,66 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
-
+
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}\n".format(text))
return asr_result_list
-
+
return _forward
def inference_paraformer(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ output_dir: Optional[str] = None,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
-
+
if kwargs.get("device", None) == "cpu":
ngpu = 0
if ngpu >= 1 and torch.cuda.is_available():
@@ -301,10 +265,10 @@
else:
device = "cpu"
batch_size = 1
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -326,9 +290,9 @@
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
-
+
speech2text = Speech2TextParaformer(**speech2text_kwargs)
-
+
if timestamp_model_file is not None:
speechtext2timestamp = Speech2Timestamp(
timestamp_cmvn_file=cmvn_file,
@@ -337,16 +301,16 @@
)
else:
speechtext2timestamp = None
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
):
-
+
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
@@ -354,30 +318,28 @@
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
-
+
forward_time_total = 0.0
length_total = 0.0
finish_count = 0
@@ -390,17 +352,17 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
+
logging.info("decoding, utt_id: {}".format(keys))
# N-best list of (text, token, token_int, hyp_object)
-
+
time_beg = time.time()
results = speech2text(**batch)
if len(results) < 1:
@@ -416,10 +378,10 @@
100 * forward_time / (
length * lfr_factor))
logging.info(rtf_cur)
-
+
for batch_id in range(_bs):
result = [results[batch_id][:-2]]
-
+
key = keys[batch_id]
for n, result in zip(range(1, nbest + 1), result):
text, token, token_int, hyp = result[0], result[1], result[2], result[3]
@@ -438,13 +400,13 @@
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["rtf"][key] = rtf_cur
-
+
if text is not None:
if use_timestamp and timestamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
@@ -465,7 +427,7 @@
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
-
+
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
forward_time_total,
@@ -475,74 +437,73 @@
if writer is not None:
ibest_writer["rtf"]["rtf_avf"] = rtf_avg
return asr_result_list
-
+
return _forward
def inference_paraformer_vad_punc(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- vad_infer_config: Optional[str] = None,
- vad_model_file: Optional[str] = None,
- vad_cmvn_file: Optional[str] = None,
- time_stamp_writer: bool = True,
- punc_infer_config: Optional[str] = None,
- punc_model_file: Optional[str] = None,
- outputs_dict: Optional[bool] = True,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ vad_infer_config: Optional[str] = None,
+ vad_model_file: Optional[str] = None,
+ vad_cmvn_file: Optional[str] = None,
+ time_stamp_writer: bool = True,
+ punc_infer_config: Optional[str] = None,
+ punc_model_file: Optional[str] = None,
+ outputs_dict: Optional[bool] = True,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
else:
hotword_list_or_file = None
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2vadsegment
speech2vadsegment_kwargs = dict(
vad_infer_config=vad_infer_config,
@@ -553,7 +514,7 @@
)
# logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
+
# 3. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -579,12 +540,12 @@
text2punc = None
if punc_model_file is not None:
text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
+
if output_dir is not None:
writer = DatadirWriter(output_dir)
ibest_writer = writer[f"1best_recog"]
ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -592,43 +553,41 @@
param_dict: dict = None,
**kwargs,
):
-
+
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
-
+
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
-
+
batch_size_token = kwargs.get("batch_size_token", 6000)
print("batch_size_token: ", batch_size_token)
-
+
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=1,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
if param_dict is not None:
use_timestamp = param_dict.get('use_timestamp', True)
else:
use_timestamp = True
-
+
finish_count = 0
file_count = 1
lfr_factor = 6
@@ -639,7 +598,7 @@
if output_path is not None:
writer = DatadirWriter(output_path)
ibest_writer = writer[f"1best_recog"]
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
@@ -648,21 +607,27 @@
beg_vad = time.time()
vad_results = speech2vadsegment(**batch)
end_vad = time.time()
- print("time cost vad: ", end_vad-beg_vad)
+ print("time cost vad: ", end_vad - beg_vad)
_, vadsegments = vad_results[0], vad_results[1][0]
-
+
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
+
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
+
batch_size_token_ms = batch_size_token*60
+ if speech2text.device == "cpu":
+ batch_size_token_ms = 0
+ batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):
batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
- if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
+ if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
+ 0]) < batch_size_token_ms:
continue
batch_size_token_ms_cum = 0
end_idx = j + 1
@@ -675,11 +640,11 @@
results = speech2text(**batch)
end_asr = time.time()
print("time cost asr: ", end_asr - beg_asr)
-
+
if len(results) < 1:
results = [["", [], [], [], [], [], []]]
results_sorted.extend(results)
-
+
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
@@ -695,12 +660,12 @@
t[1] += vadsegments[j][0]
result[4] += restored_data[j][4]
# result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
+
key = keys[0]
# result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = result[4] if len(result[4]) > 0 else None
-
+
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
@@ -714,23 +679,23 @@
postprocessed_result[2]
else:
text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
+
text_postprocessed_punc = text_postprocessed
punc_id_list = []
if len(word_lists) > 0 and text2punc is not None:
beg_punc = time.time()
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
end_punc = time.time()
- print("time cost punc: ", end_punc-beg_punc)
-
+ print("time cost punc: ", end_punc - beg_punc)
+
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
item['text_postprocessed'] = text_postprocessed
if time_stamp_postprocessed != "":
item['time_stamp'] = time_stamp_postprocessed
-
+
item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
+
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
@@ -743,11 +708,12 @@
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
+
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
return asr_result_list
-
+
return _forward
+
def inference_paraformer_online(
maxlenratio: float,
@@ -779,7 +745,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
@@ -848,7 +813,7 @@
data = yaml.load(f, Loader=yaml.Loader)
return data
- def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+ def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
if len(cache) > 0:
return cache
config = _read_yaml(asr_train_config)
@@ -864,14 +829,15 @@
return cache
- def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
+ def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
if len(cache) > 0:
config = _read_yaml(asr_train_config)
enc_output_size = config["encoder_conf"]["output_size"]
feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
@@ -893,7 +859,13 @@
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ try:
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ except:
+ raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+ if raw_inputs.ndim == 2:
+ raw_inputs = raw_inputs[:, 0]
+ raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
@@ -916,7 +888,7 @@
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
sample_offset = 0
speech_length = raw_inputs.shape[1]
- stride_size = chunk_size[1] * 960
+ stride_size = chunk_size[1] * 960
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
@@ -945,42 +917,41 @@
def inference_uniasr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- ngram_file: Optional[str] = None,
- cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ ngram_file: Optional[str] = None,
+ cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ token_num_relax: int = 1,
+ decoding_ind: int = 0,
+ decoding_mode: str = "model1",
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -989,17 +960,17 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
if param_dict is not None and "decoding_model" in param_dict:
if param_dict["decoding_model"] == "fast":
decoding_ind = 0
@@ -1012,10 +983,10 @@
decoding_mode = "model2"
else:
raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1042,7 +1013,7 @@
decoding_mode=decoding_mode,
)
speech2text = Speech2TextUniASR(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1055,19 +1026,17 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1078,14 +1047,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -1093,7 +1062,7 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
logging.info(f"Utterance: {key}")
@@ -1101,12 +1070,12 @@
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1116,42 +1085,41 @@
if writer is not None:
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
-
+
return _forward
def inference_mfcca(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -1160,20 +1128,20 @@
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1197,7 +1165,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2TextMFCCA(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1210,20 +1178,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
fs=fs,
mc=True,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1234,14 +1200,14 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
@@ -1249,19 +1215,19 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1271,42 +1237,43 @@
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
-
+
return _forward
+
def inference_transducer(
- output_dir: str,
- batch_size: int,
- dtype: str,
- beam_size: int,
- ngpu: int,
- seed: int,
- lm_weight: float,
- nbest: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str],
- beam_search_config: Optional[dict],
- lm_train_config: Optional[str],
- lm_file: Optional[str],
- model_tag: Optional[str],
- token_type: Optional[str],
- bpemodel: Optional[str],
- key_file: Optional[str],
- allow_variable_data_keys: bool,
- quantize_asr_model: Optional[bool],
- quantize_modules: Optional[List[str]],
- quantize_dtype: Optional[str],
- streaming: Optional[bool],
- simu_streaming: Optional[bool],
- chunk_size: Optional[int],
- left_context: Optional[int],
- right_context: Optional[int],
- display_partial_hypotheses: bool,
- **kwargs,
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
) -> None:
"""Transducer model inference.
Args:
@@ -1340,7 +1307,6 @@
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
- assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
@@ -1387,7 +1353,7 @@
model_tag=model_tag,
**speech2text_kwargs,
)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1396,131 +1362,123 @@
**kwargs,
):
# 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(
- speech2text.asr_train_args, False
- ),
- collate_fn=ASRTask.build_collate_fn(
- speech2text.asr_train_args, False
- ),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
# 4 .Start for-loop
with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
-
+
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
assert len(batch.keys()) == 1
-
+
try:
if speech2text.streaming:
speech = batch["speech"]
-
+
_steps = len(speech) // speech2text._ctx
_end = 0
for i in range(_steps):
_end = (i + 1) * speech2text._ctx
-
+
speech2text.streaming_decode(
- speech[i * speech2text._ctx : _end], is_final=False
+ speech[i * speech2text._ctx: _end], is_final=False
)
-
+
final_hyps = speech2text.streaming_decode(
- speech[_end : len(speech)], is_final=True
+ speech[_end: len(speech)], is_final=True
)
elif speech2text.simu_streaming:
final_hyps = speech2text.simu_streaming_decode(**batch)
else:
final_hyps = speech2text(**batch)
-
+
results = speech2text.hypotheses_to_results(final_hyps)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
results = [[" ", ["<space>"], [2], hyp]] * nbest
-
+
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
ibest_writer = writer[f"{n}best_recog"]
-
+
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
-
+
if text is not None:
ibest_writer["text"][key] = text
-
return _forward
def inference_sa_asr(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
@@ -1544,7 +1502,7 @@
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2TextSAASR(**speech2text_kwargs)
-
+
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
@@ -1557,20 +1515,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
@@ -1581,7 +1537,7 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
@@ -1595,20 +1551,20 @@
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["sil"], [2], hyp]] * nbest
-
+
# Only supporting batch_size==1
key = keys[0]
for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
-
+
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
ibest_writer["text_id"][key] = text_id
-
+
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
@@ -1617,12 +1573,12 @@
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
-
+
logging.info("uttid: {}".format(key))
logging.info("text predictions: {}".format(text))
logging.info("text_id predictions: {}\n".format(text_id))
return asr_result_list
-
+
return _forward
@@ -1660,7 +1616,7 @@
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
+
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
@@ -1670,7 +1626,7 @@
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
-
+
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
@@ -1703,7 +1659,7 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -1725,7 +1681,7 @@
default=False,
help="MultiChannel input",
)
-
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -1788,7 +1744,7 @@
default={},
help="The keyword arguments for transducer beam search.",
)
-
+
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
@@ -1835,7 +1791,7 @@
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
-
+
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
@@ -1860,7 +1816,7 @@
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
)
-
+
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
@@ -1918,7 +1874,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
-
if __name__ == "__main__":
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
index 4460e3d..3efa641 100755
--- a/funasr/bin/diar_infer.py
+++ b/funasr/bin/diar_infer.py
@@ -1,41 +1,27 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
import os
-import sys
+from collections import OrderedDict
from pathlib import Path
from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from collections import OrderedDict
import numpy as np
-import soundfile
import torch
-from torch.nn import functional as F
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
+from torch.nn import functional as F
+
from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.tasks.diar import DiarTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.torch_utils.device_funcs import to_device
+from funasr.utils.misc import statistic_model_parameters
+
class Speech2DiarizationEEND:
"""Speech2Diarlization class
@@ -58,13 +44,14 @@
device: str = "cpu",
dtype: str = "float32",
):
- assert check_argument_types()
# 1. Build Diarization model
- diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
+ diar_model, diar_train_args = build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
- device=device
+ device=device,
+ task_name="diar",
+ mode="eend-ola",
)
frontend = None
if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
@@ -99,7 +86,6 @@
diarization results
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -117,36 +103,6 @@
results = self.diar_model.estimate_sequential(**batch)
return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Diarization instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Diarization: Speech2Diarization instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2DiarizationEEND(**kwargs)
class Speech2DiarizationSOND:
@@ -174,13 +130,14 @@
smooth_size: int = 83,
dur_threshold: float = 10,
):
- assert check_argument_types()
# TODO: 1. Build Diarization model
- diar_model, diar_train_args = DiarTask.build_model_from_file(
+ diar_model, diar_train_args = build_model_from_file(
config_file=diar_train_config,
model_file=diar_model_file,
- device=device
+ device=device,
+ task_name="diar",
+ mode="sond",
)
logging.info("diar_model: {}".format(diar_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
@@ -248,7 +205,7 @@
ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
logits_idx = F.upsample(
logits_idx.unsqueeze(1).float(),
- size=(ut, ),
+ size=(ut,),
mode="nearest",
).squeeze(1).long()
logits_idx = logits_idx[0].tolist()
@@ -268,7 +225,7 @@
if spk not in results:
results[spk] = []
if dur > self.dur_threshold:
- results[spk].append((st, st+dur))
+ results[spk].append((st, st + dur))
# sort segments in start time ascending
for spk in results:
@@ -292,7 +249,6 @@
diarization results for each speaker
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -314,37 +270,3 @@
results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
return results, pse_labels
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Xvector instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Xvector: Speech2Xvector instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2DiarizationSOND(**kwargs)
-
-
-
-
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index e0d900e..03c9659 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,5 +1,5 @@
+# !/usr/bin/env python3
# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -8,47 +8,27 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
-from collections import OrderedDict
import numpy as np
import soundfile
import torch
-from torch.nn import functional as F
-from typeguard import check_argument_types
-from typeguard import check_return_type
from scipy.signal import medfilt
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
+
+from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
-from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+
def inference_sond(
diar_train_config: str,
@@ -71,7 +51,6 @@
mode: str = "sond",
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -94,7 +73,8 @@
set_all_random_seed(seed)
# 2a. Build speech2xvec [Optional]
- if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+ if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[
+ "extract_profile"]:
assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
sv_train_config = param_dict["sv_train_config"]
@@ -139,7 +119,7 @@
rst = []
mid = uttid.rsplit("-", 1)[0]
for key in results:
- results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
+ results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
if out_format == "vad":
for spk, segs in results.items():
rst.append("{} {}".format(spk, segs))
@@ -176,7 +156,7 @@
example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
for x in example]
speech = example[0]
- logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
+ logging.info("Extracting profiles for {} waveforms".format(len(example) - 1))
profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
profile = torch.cat(profile, dim=0)
yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
@@ -186,16 +166,15 @@
raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
else:
# 3. Build data-iterator
- loader = DiarTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="diar",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
+ use_collate_fn=False,
)
# 7. Start for-loop
@@ -235,6 +214,7 @@
return _forward
+
def inference_eend(
diar_train_config: str,
diar_model_file: str,
@@ -251,7 +231,6 @@
param_dict: Optional[dict] = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -306,16 +285,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
- loader = EENDOLADiarTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="diar",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
- collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
# 3. Start for-loop
@@ -362,8 +339,6 @@
return _forward
-
-
def inference_launch(mode, **kwargs):
if mode == "sond":
return inference_sond(mode=mode, **kwargs)
@@ -386,6 +361,7 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 1d99fce..236a923 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,40 +7,24 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
from typing import Any
from typing import List
+from typing import Optional
+from typing import Union
import numpy as np
import torch
from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import float_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
@@ -48,42 +32,41 @@
def inference_lm(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build Model
- model, train_args = LMTask.build_model_from_file(
- train_config, model_file, device)
+ model, train_args = build_model_from_file(
+ train_config, model_file, None, device, "lm")
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
logging.info(f"Model:\n{model}")
-
+
preprocessor = LMPreprocessor(
train=False,
token_type=train_args.token_type,
@@ -96,12 +79,12 @@
split_with_space=split_with_space,
seg_dict_file=seg_dict_file
)
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
results = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -109,7 +92,7 @@
writer = DatadirWriter(output_path)
else:
writer = None
-
+
if raw_inputs != None:
line = raw_inputs.strip()
key = "lm demo"
@@ -121,7 +104,7 @@
batch['text'] = line
if preprocessor != None:
batch = preprocessor(key, batch)
-
+
# Force data-precision
for name in batch:
value = batch[name]
@@ -138,11 +121,11 @@
else:
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
batch[name] = value
-
+
batch["text_lengths"] = torch.from_numpy(
np.array([len(batch["text"])], dtype='int32'))
batch["text"] = np.expand_dims(batch["text"], axis=0)
-
+
with torch.no_grad():
batch = to_device(batch, device)
if ngpu <= 1:
@@ -173,7 +156,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -189,22 +172,20 @@
if writer is not None:
writer["ppl"][key + ":\n"] = ppl_out
results.append(item)
-
+
return results
-
+
# 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="lm",
+ preprocess_args=train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
# 4. Start for-loop
total_nll = 0.0
total_ntokens = 0
@@ -214,7 +195,7 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+
ppl_out_batch = ""
with torch.no_grad():
batch = to_device(batch, device)
@@ -247,7 +228,7 @@
word_nll=round(word_nll.item(), 8)
)
pre_word = cur_word
-
+
sent_nll_mean = sent_nll.mean().cpu().numpy()
sent_nll_sum = sent_nll.sum().cpu().numpy()
if log_base is None:
@@ -265,9 +246,9 @@
writer["ppl"][key + ":\n"] = ppl_out
writer["utt2nll"][key] = str(utt2nll)
results.append(item)
-
+
ppl_out_all += ppl_out_batch
-
+
assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
# nll: (B, L) -> (B,)
nll = nll.detach().cpu().numpy().sum(1)
@@ -275,12 +256,12 @@
lengths = lengths.detach().cpu().numpy()
total_nll += nll.sum()
total_ntokens += lengths.sum()
-
+
if log_base is None:
ppl = np.exp(total_nll / total_ntokens)
else:
ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
+
avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
total_nll=round(-total_nll.item(), 4),
total_ppl=round(ppl.item(), 4)
@@ -290,9 +271,9 @@
if writer is not None:
writer["ppl"]["AVG PPL : "] = avg_ppl
results.append(item)
-
+
return results
-
+
return _forward
@@ -302,7 +283,8 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
-
+
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
@@ -407,4 +389,3 @@
if __name__ == "__main__":
main()
-
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 4b6cd27..ac96811 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -1,46 +1,32 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
-import logging
-from pathlib import Path
-import sys
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Any
-from typing import List
import numpy as np
import torch
-from typeguard import check_argument_types
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
+from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
):
# Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -144,16 +130,16 @@
class Text2PuncVADRealtime:
-
+
def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
):
# Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
self.device = device
# Wrape model to make model.nll() data-parallel
self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -178,7 +164,7 @@
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
-
+
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
if cache is not None and len(cache) > 0:
@@ -215,7 +201,7 @@
if indices.size()[0] != 1:
punctuations = torch.squeeze(indices)
assert punctuations.size()[0] == len(mini_sentence)
-
+
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
@@ -226,7 +212,7 @@
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
last_comma_index = i
-
+
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
@@ -235,11 +221,11 @@
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
-
+
punctuations_np = punctuations.cpu().numpy()
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
-
+
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
@@ -256,7 +242,7 @@
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
-
+
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
@@ -267,5 +253,3 @@
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
return sentence_out, sentence_punc_list_out, cache_out
-
-
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 7f60f81..5d917f5 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,57 +7,36 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.types import float_or_none
-
-import argparse
-import logging
from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
from typing import Any
from typing import List
+from typing import Optional
+from typing import Union
-import numpy as np
import torch
-from typeguard import check_argument_types
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
-from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
+
def inference_punc(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -73,11 +52,11 @@
text2punc = Text2Punc(train_config, model_file, device)
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
):
results = []
split_size = 20
@@ -121,22 +100,22 @@
return _forward
+
def inference_punc_vad_realtime(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- #cache: list,
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ # cache: list,
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@@ -150,11 +129,11 @@
text2punc = Text2PuncVADRealtime(train_config, model_file, device)
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
):
results = []
split_size = 10
@@ -177,7 +156,6 @@
return _forward
-
def inference_launch(mode, **kwargs):
if mode == "punc":
return inference_punc(**kwargs)
@@ -186,6 +164,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -267,7 +246,6 @@
kwargs.pop("njob", None)
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
if __name__ == "__main__":
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 1517bfa..346440a 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,35 +1,22 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-import os
-import sys
from pathlib import Path
from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
-from kaldiio import WriteHelper
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
+
class Speech2Xvector:
"""Speech2Xvector class
@@ -53,13 +40,15 @@
streaming: bool = False,
embedding_node: str = "resnet1_dense",
):
- assert check_argument_types()
# TODO: 1. Build SV model
- sv_model, sv_train_args = SVTask.build_model_from_file(
+ sv_model, sv_train_args = build_model_from_file(
config_file=sv_train_config,
model_file=sv_model_file,
- device=device
+ cmvn_file=None,
+ device=device,
+ task_name="sv",
+ mode="sv",
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -116,7 +105,6 @@
embedding, ref_embedding, similarity_score
"""
- assert check_argument_types()
self.sv_model.eval()
embedding = self.calculate_embedding(speech)
ref_emb, score = None, None
@@ -125,39 +113,4 @@
score = torch.cosine_similarity(embedding, ref_emb)
results = (embedding, ref_emb, score)
- assert check_return_type(results)
return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Xvector instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Xvector: Speech2Xvector instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2Xvector(**kwargs)
-
-
-
-
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index dbddd9f..2f9e276 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,20 +7,6 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -29,62 +15,58 @@
import numpy as np
import torch
from kaldiio import WriteHelper
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
-from funasr.torch_utils.device_funcs import to_device
+from funasr.bin.sv_infer import Speech2Xvector
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils.misc import statistic_model_parameters
-from funasr.bin.sv_infer import Speech2Xvector
+
def inference_sv(
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- sv_train_config: Optional[str] = "sv.yaml",
- sv_model_file: Optional[str] = "sv.pb",
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- param_dict: Optional[dict] = None,
- **kwargs,
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pb",
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ sv_threshold: float = 0.9465,
+ param_dict: Optional[dict] = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
@@ -95,37 +77,33 @@
embedding_node=embedding_node
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector.from_pretrained(
- model_tag=model_tag,
- **speech2xvector_kwargs,
- )
+ speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
speech2xvector.sv_model.eval()
-
+
def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
+
# 3. Build data-iterator
- loader = SVTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="sv",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
+ use_collate_fn=False,
)
-
+
# 7 .Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
embd_writer, ref_embd_writer, score_writer = None, None, None
@@ -139,7 +117,7 @@
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
embedding, ref_embedding, score = speech2xvector(**batch)
# Only supporting batch_size==1
key = keys[0]
@@ -161,18 +139,16 @@
score_writer = open(os.path.join(output_path, "score.txt"), "w")
ref_embd_writer(key, ref_embedding[0].cpu().numpy())
score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
+
if output_path is not None:
embd_writer.close()
if ref_embd_writer is not None:
ref_embd_writer.close()
score_writer.close()
-
+
return sv_result_list
-
+
return _forward
-
-
def inference_launch(mode, **kwargs):
@@ -182,6 +158,7 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
index dc565d0..6ec83a8 100755
--- a/funasr/bin/tokenize_text.py
+++ b/funasr/bin/tokenize_text.py
@@ -7,7 +7,6 @@
from typing import List
from typing import Optional
-from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.text.build_tokenizer import build_tokenizer
@@ -81,7 +80,6 @@
cleaner: Optional[str],
g2p: Optional[str],
):
- assert check_argument_types()
logging.basicConfig(
level=log_level,
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index 4ddcba4..ede579c 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -1,57 +1,33 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-from optparse import Option
-import sys
-import json
from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
import numpy as np
import torch
-from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner as ASRTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
+from funasr.torch_utils.device_funcs import to_device
class Speech2Timestamp:
def __init__(
- self,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- timestamp_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- **kwargs,
+ self,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ timestamp_cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ dtype: str = "float32",
+ **kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
- tp_model, tp_train_args = ASRTask.build_model_from_file(
- timestamp_infer_config, timestamp_model_file, device=device
+ tp_model, tp_train_args = build_model_from_file(
+ timestamp_infer_config, timestamp_model_file, cmvn_file=None, device=device, task_name="asr", mode="tp"
)
if 'cuda' in device:
tp_model = tp_model.cuda() # force model to cuda
@@ -59,13 +35,12 @@
frontend = None
if tp_train_args.frontend is not None:
frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
-
+
logging.info("tp_model: {}".format(tp_model))
logging.info("tp_train_args: {}".format(tp_train_args))
tp_model.to(dtype=getattr(torch, dtype)).eval()
logging.info(f"Decoding device={device}, dtype={dtype}")
-
self.tp_model = tp_model
self.tp_train_args = tp_train_args
@@ -79,15 +54,14 @@
self.encoder_downsampling_factor = 1
if tp_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
-
+
@torch.no_grad()
def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- text_lengths: Union[torch.Tensor, np.ndarray] = None
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ text_lengths: Union[torch.Tensor, np.ndarray] = None
):
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -113,8 +87,6 @@
enc = enc[0]
# c. Forward Predictor
- _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
+ _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len,
+ text_lengths.to(self.device) + 1)
return us_alphas, us_peaks
-
-
-
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index a8d67ef..6c10254 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -8,87 +8,64 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-from optparse import Option
-import sys
-import json
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
import numpy as np
import torch
-from typeguard import check_argument_types
-from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner as ASRTask
-from funasr.torch_utils.device_funcs import to_device
+from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_infer import Speech2Timestamp
+
def inference_tp(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- timestamp_infer_config: Optional[str],
- timestamp_model_file: Optional[str],
- timestamp_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- split_with_space: bool = True,
- seg_dict_file: Optional[str] = None,
- **kwargs,
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ timestamp_infer_config: Optional[str],
+ timestamp_model_file: Optional[str],
+ timestamp_cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ split_with_space: bool = True,
+ seg_dict_file: Optional[str] = None,
+ **kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2vadsegment
speechtext2timestamp_kwargs = dict(
timestamp_infer_config=timestamp_infer_config,
@@ -99,7 +76,7 @@
)
logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
-
+
preprocessor = LMPreprocessor(
train=False,
token_type=speechtext2timestamp.tp_train_args.token_type,
@@ -112,21 +89,21 @@
split_with_space=split_with_space,
seg_dict_file=seg_dict_file,
)
-
+
if output_dir is not None:
writer = DatadirWriter(output_dir)
tp_writer = writer[f"timestamp_prediction"]
# ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
else:
tp_writer = None
-
+
def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs
):
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
writer = None
@@ -140,32 +117,31 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speechtext2timestamp.tp_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=preprocessor,
- collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
-
+
tp_result_list = []
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
+
logging.info("timestamp predicting, utt_id: {}".format(keys))
_batch = {'speech': batch['speech'],
'speech_lengths': batch['speech_lengths'],
'text_lengths': batch['text_lengths']}
us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
-
+
for batch_id in range(_bs):
key = keys[batch_id]
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
@@ -178,10 +154,8 @@
tp_writer["tp_time"][key + '#'] = str(ts_list)
tp_result_list.append(item)
return tp_result_list
-
+
return _forward
-
-
def inference_launch(mode, **kwargs):
@@ -190,6 +164,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -306,7 +281,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
-
if __name__ == "__main__":
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 21e1943..1dc3fb5 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -299,7 +301,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
index e1698d0..c60a8f1 100644
--- a/funasr/bin/vad_infer.py
+++ b/funasr/bin/vad_infer.py
@@ -1,42 +1,22 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-import os
-import sys
-import json
+import math
from pathlib import Path
-from typing import Any
+from typing import Dict
from typing import List
-from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
-import math
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-
+from funasr.torch_utils.device_funcs import to_device
class Speech2VadSegment:
@@ -61,11 +41,10 @@
dtype: str = "float32",
**kwargs,
):
- assert check_argument_types()
# 1. Build vad model
- vad_model, vad_infer_args = VADTask.build_model_from_file(
- vad_infer_config, vad_model_file, device
+ vad_model, vad_infer_args = build_model_from_file(
+ vad_infer_config, vad_model_file, None, device, task_name="vad"
)
frontend = None
if vad_infer_args.frontend is not None:
@@ -95,7 +74,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -128,12 +106,13 @@
"in_cache": in_cache
}
# a. To device
- #batch = to_device(batch, device=self.device)
+ # batch = to_device(batch, device=self.device)
segments_part, in_cache = self.vad_model(**batch)
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
return fbanks, segments
+
class Speech2VadSegmentOnline(Speech2VadSegment):
"""Speech2VadSegmentOnline class
@@ -146,13 +125,13 @@
[[10, 230], [245, 450], ...]
"""
+
def __init__(self, **kwargs):
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
self.frontend = None
if self.vad_infer_args.frontend is not None:
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
@torch.no_grad()
def __call__(
@@ -167,7 +146,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -198,5 +176,3 @@
# in_cache.update(batch['in_cache'])
# in_cache = {key: value for key, value in batch['in_cache'].items()}
return fbanks, segments, in_cache
-
-
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index b17d058..47af011 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,58 +1,33 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
+
torch.set_num_threads(1)
import argparse
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
import json
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Union
-from typing import Dict
-import math
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
+
def inference_vad(
batch_size: int,
@@ -71,10 +46,8 @@
num_workers: int = 1,
**kwargs,
):
- assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
-
logging.basicConfig(
level=log_level,
@@ -112,16 +85,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="vad",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
finish_count = 0
@@ -157,6 +128,7 @@
return _forward
+
def inference_vad_online(
batch_size: int,
ngpu: int,
@@ -174,8 +146,6 @@
num_workers: int = 1,
**kwargs,
):
- assert check_argument_types()
-
logging.basicConfig(
level=log_level,
@@ -214,16 +184,14 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="vad",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
finish_count = 0
@@ -273,8 +241,6 @@
return _forward
-
-
def inference_launch(mode, **kwargs):
if mode == "offline":
return inference_vad(**kwargs)
@@ -283,6 +249,7 @@
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -405,5 +372,6 @@
inference_pipeline = inference_launch(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
if __name__ == "__main__":
main()
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 517c85b..632c134 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -41,7 +41,7 @@
"--cmvn_file",
type=str_or_none,
default=None,
- help="The file path of noise scp file.",
+ help="The path of cmvn file.",
)
elif args.task_name == "pretrain":
@@ -75,12 +75,29 @@
default=None,
help="The number of input dimension of the feature",
)
+ task_parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The path of cmvn file.",
+ )
elif args.task_name == "diar":
from funasr.build_utils.build_diar_model import class_choices_list
for class_choices in class_choices_list:
class_choices.add_arguments(task_parser)
+ elif args.task_name == "sv":
+ from funasr.build_utils.build_sv_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 46c11b0..a76b204 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -20,16 +20,22 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
+
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -39,6 +45,7 @@
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
@@ -86,10 +93,13 @@
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
+ neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
+
),
default="asr",
)
@@ -106,6 +116,27 @@
chunk_conformer=ConformerChunkEncoder,
),
default="rnn",
+)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
)
encoder_choices2 = ClassChoices(
"encoder2",
@@ -131,6 +162,7 @@
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
default="rnn",
)
@@ -222,24 +254,33 @@
rnnt_decoder_choices,
# --joint_network and --joint_network_conf
joint_network_choices,
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
]
def build_asr_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
args.token_list = list(token_list)
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
else:
+ token_list = None
vocab_size = None
# frontend
- if args.input_size is None:
+ if hasattr(args, "input_size") and args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend':
+ if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
@@ -248,7 +289,7 @@
args.frontend = None
args.frontend_conf = {}
frontend = None
- input_size = args.input_size
+ input_size = args.input_size if hasattr(args, "input_size") else None
# data augmentation for spectrogram
if args.specaug is not None:
@@ -260,7 +301,10 @@
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
+ if args.model == "mfcca":
+ normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
+ else:
+ normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
@@ -294,7 +338,8 @@
token_list=token_list,
**args.model_conf,
)
- elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+ elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
+ "contextual_paraformer", "neatcontextual_paraformer"]:
# predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
@@ -363,10 +408,15 @@
**args.model_conf,
)
elif args.model == "timestamp_prediction":
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
model_class = model_choices.get_class(args.model)
model = model_class(
frontend=frontend,
encoder=encoder,
+ predictor=predictor,
token_list=token_list,
**args.model_conf,
)
@@ -413,6 +463,33 @@
joint_network=joint_network,
**args.model_conf,
)
+ elif args.model == "sa_asr":
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
index 6406404..0ea3127 100644
--- a/funasr/build_utils/build_diar_model.py
+++ b/funasr/build_utils/build_diar_model.py
@@ -178,14 +178,18 @@
def build_diar_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
- vocab_size = len(token_list)
- logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
else:
- vocab_size = None
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
# frontend
if args.input_size is None:
@@ -205,7 +209,7 @@
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
- if args.model_name == "sond":
+ if args.model == "sond":
# data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
@@ -243,11 +247,7 @@
# decoder
decoder_class = decoder_choices.get_class(args.decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder.output_size(),
- **args.decoder_conf,
- )
+ decoder = decoder_class(**args.decoder_conf)
# logger aggregator
if getattr(args, "label_aggregator", None) is not None:
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
index 8f4a958..f78a20e 100644
--- a/funasr/build_utils/build_lm_model.py
+++ b/funasr/build_utils/build_lm_model.py
@@ -34,10 +34,14 @@
def build_lm_model(args):
# token_list
- if args.token_list is not None:
- with open(args.token_list) as f:
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
else:
@@ -47,6 +51,7 @@
lm_class = lm_choices.get_class(args.lm)
lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+ args.model = args.model if hasattr(args, "model") else "lm"
model_class = model_choices.get_class(args.model)
model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 13a6faa..be8f910 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,9 +1,10 @@
from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_diar_model import build_diar_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_diar_model import build_diar_model
def build_model(args):
@@ -19,6 +20,8 @@
model = build_vad_model(args)
elif args.task_name == "diar":
model = build_diar_model(args)
+ elif args.task_name == "sv":
+ model = build_sv_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
new file mode 100644
index 0000000..26542cd
--- /dev/null
+++ b/funasr/build_utils/build_model_from_file.py
@@ -0,0 +1,191 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Union
+
+import torch
+import yaml
+
+from funasr.build_utils.build_model import build_model
+from funasr.models.base_model import FunASRModel
+
+
+def build_model_from_file(
+ config_file: Union[Path, str] = None,
+ model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ task_name: str = "asr",
+ mode: str = "paraformer",
+):
+ """Build model from the files.
+
+ This method is used for inference or fine-tuning.
+
+ Args:
+ config_file: The yaml file saved when training.
+ model_file: The model file saved when training.
+ device: Device type, "cpu", "cuda", or "cuda:N".
+
+ """
+ if config_file is None:
+ assert model_file is not None, (
+ "The argument 'model_file' must be provided "
+ "if the argument 'config_file' is not specified."
+ )
+ config_file = Path(model_file).parent / "config.yaml"
+ else:
+ config_file = Path(config_file)
+
+ with config_file.open("r", encoding="utf-8") as f:
+ args = yaml.safe_load(f)
+ if cmvn_file is not None:
+ args["cmvn_file"] = cmvn_file
+ args = argparse.Namespace(**args)
+ args.task_name = task_name
+ model = build_model(args)
+ if not isinstance(model, FunASRModel):
+ raise RuntimeError(
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
+ )
+ model.to(device)
+ model_dict = dict()
+ model_name_pth = None
+ if model_file is not None:
+ logging.info("model_file is {}".format(model_file))
+ if device == "cuda":
+ device = f"cuda:{torch.cuda.current_device()}"
+ model_dir = os.path.dirname(model_file)
+ model_name = os.path.basename(model_file)
+ if "model.ckpt-" in model_name or ".bin" in model_name:
+ model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
+ '.pb')) if ".bin" in model_name else os.path.join(
+ model_dir, "{}.pb".format(model_name))
+ if os.path.exists(model_name_pth):
+ logging.info("model_file is load from pth: {}".format(model_name_pth))
+ model_dict = torch.load(model_name_pth, map_location=device)
+ else:
+ model_dict = convert_tf2torch(model, model_file, mode)
+ model.load_state_dict(model_dict)
+ else:
+ model_dict = torch.load(model_file, map_location=device)
+ if task_name == "diar" and mode == "sond":
+ model_dict = fileter_model_dict(model_dict, model.state_dict())
+ if task_name == "vad":
+ model.encoder.load_state_dict(model_dict)
+ else:
+ model.load_state_dict(model_dict)
+ if model_name_pth is not None and not os.path.exists(model_name_pth):
+ torch.save(model_dict, model_name_pth)
+ logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+ return model, args
+
+
+def convert_tf2torch(
+ model,
+ ckpt,
+ mode,
+):
+ assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
+ logging.info("start convert tf model to torch model")
+ from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+ var_dict_tf = load_tf_dict(ckpt)
+ var_dict_torch = model.state_dict()
+ var_dict_torch_update = dict()
+ if mode == "uniasr":
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # encoder2
+ var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor2
+ var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder2
+ var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # stride_conv
+ var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ elif mode == "paraformer":
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ elif "mode" == "sond":
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # speaker encoder
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # cd scorer
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # ci scorer
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ elif "mode" == "sv":
+ # speech encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # pooling layer
+ var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ else:
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ return var_dict_torch_update
+
+ return var_dict_torch_update
+
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+ from collections import OrderedDict
+ new_dict = OrderedDict()
+ for key, value in src_dict.items():
+ if key in dest_dict:
+ new_dict[key] = value
+ else:
+ logging.info("{} is no longer needed in this model.".format(key))
+ for key, value in dest_dict.items():
+ if key not in new_dict:
+ logging.warning("{} is missed in checkpoint.".format(key))
+ return new_dict
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
new file mode 100644
index 0000000..02fc263
--- /dev/null
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -0,0 +1,65 @@
+import numpy as np
+from torch.utils.data import DataLoader
+
+from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+
+
+def build_streaming_iterator(
+ task_name,
+ preprocess_args,
+ data_path_and_name_and_type,
+ key_file: str = None,
+ batch_size: int = 1,
+ fs: dict = None,
+ mc: bool = False,
+ dtype: str = np.float32,
+ num_workers: int = 1,
+ use_collate_fn: bool = True,
+ preprocess_fn=None,
+ ngpu: int = 0,
+ train: bool = False,
+) -> DataLoader:
+ """Build DataLoader using iterable dataset"""
+
+ # preprocess
+ if preprocess_fn is not None:
+ preprocess_fn = preprocess_fn
+ elif preprocess_args is not None:
+ preprocess_args.task_name = task_name
+ preprocess_fn = build_preprocess(preprocess_args, train)
+ else:
+ preprocess_fn = None
+
+ # collate
+ if not use_collate_fn:
+ collate_fn = None
+ elif task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
+ dataset = IterableESPnetDataset(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ fs=fs,
+ mc=mc,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ )
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
diff --git a/funasr/build_utils/build_sv_model.py b/funasr/build_utils/build_sv_model.py
new file mode 100644
index 0000000..55df75a
--- /dev/null
+++ b/funasr/build_utils/build_sv_model.py
@@ -0,0 +1,256 @@
+import logging
+
+import torch
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.sv_decoder import DenseDecoder
+from funasr.models.e2e_sv import ESPnetSVModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.pooling.statistic_pooling import StatisticPooling
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ ),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ ),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ espnet=ESPnetSVModel,
+ ),
+ type_check=FunASRModel,
+ default="espnet",
+)
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ linear=LinearProjection,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ resnet34=ResNet34,
+ resnet34_sp_l2reg=ResNet34_SP_L2Reg,
+ rnn=RNNEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="resnet34",
+)
+postencoder_choices = ClassChoices(
+ name="postencoder",
+ classes=dict(
+ hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+ ),
+ type_check=AbsPostEncoder,
+ default=None,
+ optional=True,
+)
+pooling_choices = ClassChoices(
+ name="pooling_type",
+ classes=dict(
+ statistic=StatisticPooling,
+ ),
+ type_check=torch.nn.Module,
+ default="statistic",
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ dense=DenseDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="dense",
+)
+
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --postencoder and --postencoder_conf
+ postencoder_choices,
+ # --pooling and --pooling_conf
+ pooling_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+]
+
+
+def build_sv_model(args):
+ # token_list
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Speaker number: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # 6. Post-encoder block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ encoder_output_size = encoder.output_size()
+ if getattr(args, "postencoder", None) is not None:
+ postencoder_class = postencoder_choices.get_class(args.postencoder)
+ postencoder = postencoder_class(
+ input_size=encoder_output_size, **args.postencoder_conf
+ )
+ encoder_output_size = postencoder.output_size()
+ else:
+ postencoder = None
+
+ # 7. Pooling layer
+ pooling_class = pooling_choices.get_class(args.pooling_type)
+ pooling_dim = (2, 3)
+ eps = 1e-12
+ if hasattr(args, "pooling_type_conf"):
+ if "pooling_dim" in args.pooling_type_conf:
+ pooling_dim = args.pooling_type_conf["pooling_dim"]
+ if "eps" in args.pooling_type_conf:
+ eps = args.pooling_type_conf["eps"]
+ pooling_layer = pooling_class(
+ pooling_dim=pooling_dim,
+ eps=eps,
+ )
+ if args.pooling_type == "statistic":
+ encoder_output_size *= 2
+
+ # 8. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+
+ # 7. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("espnet")
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ pooling_layer=pooling_layer,
+ decoder=decoder,
+ **args.model_conf,
+ )
+
+ # FIXME(kamo): Should be done in model?
+ # 8. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
index aff99b5..03aa780 100644
--- a/funasr/build_utils/build_trainer.py
+++ b/funasr/build_utils/build_trainer.py
@@ -25,7 +25,6 @@
import torch
import torch.nn
import torch.optim
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -118,7 +117,6 @@
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
- assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@@ -156,7 +154,6 @@
def run(self) -> None:
"""Perform training. This method performs the main process of training."""
- assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
model = self.model
optimizers = self.optimizers
@@ -522,7 +519,6 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
- assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
@@ -758,7 +754,6 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
- assert check_argument_types()
ngpu = options.ngpu
# no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
index 76eb09b..6a840cf 100644
--- a/funasr/build_utils/build_vad_model.py
+++ b/funasr/build_utils/build_vad_model.py
@@ -50,6 +50,10 @@
def build_vad_model(args):
# frontend
+ if not hasattr(args, "cmvn_file"):
+ args.cmvn_file = None
+ if not hasattr(args, "init"):
+ args.init = None
if args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
diff --git a/funasr/datasets/collate_fn.py b/funasr/datasets/collate_fn.py
index d34d610..cbc1f0b 100644
--- a/funasr/datasets/collate_fn.py
+++ b/funasr/datasets/collate_fn.py
@@ -6,8 +6,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
@@ -22,7 +20,6 @@
not_sequence: Collection[str] = (),
max_sample_size=None
):
- assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
@@ -53,7 +50,6 @@
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
"""
- assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@@ -79,7 +75,6 @@
output[key + "_lengths"] = lens
output = (uttids, output)
- assert check_return_type(output)
return output
def crop_to_max_size(feature, target_size):
@@ -99,7 +94,6 @@
not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
# mainly for pre-training
- assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@@ -131,5 +125,4 @@
output[key + "_lengths"] = lens
output = (uttids, output)
- assert check_return_type(output)
return output
\ No newline at end of file
diff --git a/funasr/datasets/dataset.py b/funasr/datasets/dataset.py
index 979479c..407f6aa 100644
--- a/funasr/datasets/dataset.py
+++ b/funasr/datasets/dataset.py
@@ -23,8 +23,6 @@
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
@@ -37,7 +35,6 @@
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
- assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
@@ -284,7 +281,6 @@
max_cache_fd: int = 0,
dest_sample_rate: int = 16000,
):
- assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@@ -379,7 +375,6 @@
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
- assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
@@ -444,5 +439,4 @@
self.cache[uid] = data
retval = uid, data
- assert check_return_type(retval)
return retval
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 4b2fb1a..6398e0c 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,8 +14,8 @@
import numpy as np
import torch
import torchaudio
+import soundfile
from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
@@ -66,8 +66,17 @@
bytes = f.read()
return load_bytes(bytes)
+def load_wav(input):
+ try:
+ return torchaudio.load(input)[0].numpy()
+ except:
+ waveform, _ = soundfile.read(input, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ return np.expand_dims(waveform, axis=0)
+
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0].numpy(),
+ "sound": load_wav,
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
@@ -111,7 +120,6 @@
int_dtype: str = "long",
key_file: str = None,
):
- assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index aa5d9be..7a1a906 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -6,7 +6,6 @@
import sentencepiece as spm
from torch.utils.data import DataLoader
-from typeguard import check_argument_types
from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
@@ -43,7 +42,6 @@
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
- assert check_argument_types()
self.model = str(model)
self.sp = None
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 68b63e1..5f2c2c6 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
+import numpy as np
+import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,14 @@
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
- waveform, sampling_rate = torchaudio.load(path)
+ try:
+ waveform, sampling_rate = torchaudio.load(path)
+ except:
+ waveform, sampling_rate = soundfile.read(path, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ waveform = np.expand_dims(waveform, axis=0)
+ waveform = torch.tensor(waveform)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 758c750..cb4288c 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -11,8 +11,6 @@
import numpy as np
import scipy.signal
import soundfile
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -268,7 +266,6 @@
def _speech_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, Union[str, np.ndarray]]:
- assert check_argument_types()
if self.speech_name in data:
if self.train and (self.rirs is not None or self.noises is not None):
speech = data[self.speech_name]
@@ -355,7 +352,6 @@
speech = data[self.speech_name]
ma = np.max(np.abs(speech))
data[self.speech_name] = speech * self.speech_volume_normalize / ma
- assert check_return_type(data)
return data
def _text_process(
@@ -372,13 +368,11 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
- assert check_argument_types()
data = self._speech_process(data)
data = self._text_process(data)
@@ -445,7 +439,6 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
@@ -502,13 +495,11 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[text_n] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
- assert check_argument_types()
if self.speech_name in data:
# Nothing now: candidates:
@@ -612,7 +603,6 @@
tokens = self.tokenizer[i].text2tokens(text)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
@@ -690,7 +680,6 @@
def __call__(
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
) -> Dict[str, Union[list, np.ndarray]]:
- assert check_argument_types()
# Split words.
if isinstance(data[self.text_name], str):
split_text = self.split_words(data[self.text_name])
diff --git a/funasr/datasets/small_datasets/__init__.py b/funasr/datasets/small_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/small_datasets/__init__.py
diff --git a/funasr/datasets/small_datasets/collate_fn.py b/funasr/datasets/small_datasets/collate_fn.py
index 573f581..5fd4162 100644
--- a/funasr/datasets/small_datasets/collate_fn.py
+++ b/funasr/datasets/small_datasets/collate_fn.py
@@ -6,8 +6,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
@@ -22,7 +20,6 @@
not_sequence: Collection[str] = (),
max_sample_size=None
):
- assert check_argument_types()
self.float_pad_value = float_pad_value
self.int_pad_value = int_pad_value
self.not_sequence = set(not_sequence)
@@ -53,7 +50,6 @@
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
"""Concatenate ndarray-list to an array and convert to torch.Tensor.
"""
- assert check_argument_types()
uttids = [u for u, _ in data]
data = [d for _, d in data]
@@ -79,7 +75,6 @@
output[key + "_lengths"] = lens
output = (uttids, output)
- assert check_return_type(output)
return output
def crop_to_max_size(feature, target_size):
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
index 123f109..bee9f50 100644
--- a/funasr/datasets/small_datasets/dataset.py
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -15,8 +15,6 @@
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
@@ -24,7 +22,6 @@
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
- assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
@@ -112,7 +109,6 @@
speed_perturb: Union[list, tuple] = None,
mode: str = "train",
):
- assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@@ -207,7 +203,6 @@
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
- assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
@@ -265,5 +260,4 @@
data[name] = value
retval = uid, data
- assert check_return_type(retval)
return retval
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
index 8ee8bdc..28404e3 100644
--- a/funasr/datasets/small_datasets/length_batch_sampler.py
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -4,7 +4,6 @@
from typing import Tuple
from typing import Union
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -21,7 +20,6 @@
drop_last: bool = False,
padding: bool = True,
):
- assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index d80f48a..0ebf325 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -10,8 +10,6 @@
import numpy as np
import scipy.signal
import soundfile
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
@@ -260,7 +258,6 @@
def _speech_process(
self, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, Union[str, np.ndarray]]:
- assert check_argument_types()
if self.speech_name in data:
if self.train and (self.rirs is not None or self.noises is not None):
speech = data[self.speech_name]
@@ -347,7 +344,6 @@
speech = data[self.speech_name]
ma = np.max(np.abs(speech))
data[self.speech_name] = speech * self.speech_volume_normalize / ma
- assert check_return_type(data)
return data
def _text_process(
@@ -365,13 +361,11 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
- assert check_argument_types()
data = self._speech_process(data)
data = self._text_process(data)
@@ -439,7 +433,6 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[self.text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
@@ -496,13 +489,11 @@
tokens = self.tokenizer.text2tokens(text)
text_ints = self.token_id_converter.tokens2ids(tokens)
data[text_n] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
def __call__(
self, uid: str, data: Dict[str, Union[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
- assert check_argument_types()
if self.speech_name in data:
# Nothing now: candidates:
@@ -606,7 +597,6 @@
tokens = self.tokenizer[i].text2tokens(text)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
- assert check_return_type(data)
return data
@@ -685,7 +675,6 @@
def __call__(
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
) -> Dict[str, Union[list, np.ndarray]]:
- assert check_argument_types()
# Split words.
if isinstance(data[self.text_name], str):
split_text = self.split_words(data[self.text_name])
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index c02c299..f31f960 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,7 +1,6 @@
import json
from typing import Union, Dict
from pathlib import Path
-from typeguard import check_argument_types
import os
import logging
@@ -10,7 +9,7 @@
from funasr.export.models import get_model
import numpy as np
import random
-from funasr.utils.types import str2bool
+from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -24,8 +23,8 @@
fallback_num: int = 0,
audio_in: str = None,
calib_num: int = 200,
+ model_revision: str = None,
):
- assert check_argument_types()
self.set_all_random_seed(0)
self.cache_dir = cache_dir
@@ -41,6 +40,7 @@
self.frontend = None
self.audio_in = audio_in
self.calib_num = calib_num
+ self.model_revision = model_revision
def _export(
@@ -171,7 +171,7 @@
model_dir = tag_name
if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
- model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
+ model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
self.cache_dir = model_dir
if mode is None:
@@ -192,6 +192,7 @@
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
+ self.export_config["feats_dim"] = 560
elif mode.startswith('offline'):
from funasr.tasks.vad import VADTask
config = os.path.join(model_dir, 'vad.yaml')
@@ -229,40 +230,42 @@
# model_script = torch.jit.script(model)
model_script = model #torch.jit.trace(model)
model_path = os.path.join(path, f'{model.model_name}.onnx')
-
- torch.onnx.export(
- model_script,
- dummy_input,
- model_path,
- verbose=verbose,
- opset_version=14,
- input_names=model.get_input_names(),
- output_names=model.get_output_names(),
- dynamic_axes=model.get_dynamic_axes()
- )
+ if not os.path.exists(model_path):
+ torch.onnx.export(
+ model_script,
+ dummy_input,
+ model_path,
+ verbose=verbose,
+ opset_version=14,
+ input_names=model.get_input_names(),
+ output_names=model.get_output_names(),
+ dynamic_axes=model.get_dynamic_axes()
+ )
if self.quant:
from onnxruntime.quantization import QuantType, quantize_dynamic
import onnx
quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
- onnx_model = onnx.load(model_path)
- nodes = [n.name for n in onnx_model.graph.node]
- nodes_to_exclude = [m for m in nodes if 'output' in m]
- quantize_dynamic(
- model_input=model_path,
- model_output=quant_model_path,
- op_types_to_quantize=['MatMul'],
- per_channel=True,
- reduce_range=False,
- weight_type=QuantType.QUInt8,
- nodes_to_exclude=nodes_to_exclude,
- )
+ if not os.path.exists(quant_model_path):
+ onnx_model = onnx.load(model_path)
+ nodes = [n.name for n in onnx_model.graph.node]
+ nodes_to_exclude = [m for m in nodes if 'output' in m]
+ quantize_dynamic(
+ model_input=model_path,
+ model_output=quant_model_path,
+ op_types_to_quantize=['MatMul'],
+ per_channel=True,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ nodes_to_exclude=nodes_to_exclude,
+ )
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
- parser.add_argument('--model-name', type=str, required=True)
+ # parser.add_argument('--model-name', type=str, required=True)
+ parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
@@ -270,6 +273,7 @@
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+ parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
args = parser.parse_args()
export_model = ModelExport(
@@ -280,5 +284,8 @@
fallback_num=args.fallback_num,
audio_in=args.audio_in,
calib_num=args.calib_num,
+ model_revision=args.model_revision,
)
- export_model.export(args.model_name)
+ for model_name in args.model_name:
+ print("export model: {}".format(model_name))
+ export_model.export(model_name)
diff --git a/funasr/fileio/datadir_writer.py b/funasr/fileio/datadir_writer.py
index 67ec61c..e555ebc 100644
--- a/funasr/fileio/datadir_writer.py
+++ b/funasr/fileio/datadir_writer.py
@@ -2,8 +2,6 @@
from typing import Union
import warnings
-from typeguard import check_argument_types
-from typeguard import check_return_type
class DatadirWriter:
@@ -20,7 +18,6 @@
"""
def __init__(self, p: Union[Path, str]):
- assert check_argument_types()
self.path = Path(p)
self.chilidren = {}
self.fd = None
@@ -31,7 +28,6 @@
return self
def __getitem__(self, key: str) -> "DatadirWriter":
- assert check_argument_types()
if self.fd is not None:
raise RuntimeError("This writer points out a file")
@@ -41,11 +37,9 @@
self.has_children = True
retval = self.chilidren[key]
- assert check_return_type(retval)
return retval
def __setitem__(self, key: str, value: str):
- assert check_argument_types()
if self.has_children:
raise RuntimeError("This writer points out a directory")
if key in self.keys:
diff --git a/funasr/fileio/npy_scp.py b/funasr/fileio/npy_scp.py
index 26666b6..2bd5b58 100644
--- a/funasr/fileio/npy_scp.py
+++ b/funasr/fileio/npy_scp.py
@@ -3,7 +3,6 @@
from typing import Union
import numpy as np
-from typeguard import check_argument_types
from funasr.fileio.read_text import read_2column_text
@@ -25,7 +24,6 @@
"""
def __init__(self, outdir: Union[Path, str], scpfile: Union[Path, str]):
- assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)
@@ -73,7 +71,6 @@
"""
def __init__(self, fname: Union[Path, str]):
- assert check_argument_types()
self.fname = Path(fname)
self.data = read_2column_text(fname)
diff --git a/funasr/fileio/rand_gen_dataset.py b/funasr/fileio/rand_gen_dataset.py
index 2faef3a..699e67a 100644
--- a/funasr/fileio/rand_gen_dataset.py
+++ b/funasr/fileio/rand_gen_dataset.py
@@ -3,7 +3,6 @@
from typing import Union
import numpy as np
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
@@ -29,7 +28,6 @@
dtype: Union[str, np.dtype] = "float32",
loader_type: str = "csv_int",
):
- assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)
@@ -68,7 +66,6 @@
dtype: Union[str, np.dtype] = "int64",
loader_type: str = "csv_int",
):
- assert check_argument_types()
shape_file = Path(shape_file)
self.utt2shape = load_num_sequence_text(shape_file, loader_type)
self.dtype = np.dtype(dtype)
diff --git a/funasr/fileio/read_text.py b/funasr/fileio/read_text.py
index e26e7a1..f140c31 100644
--- a/funasr/fileio/read_text.py
+++ b/funasr/fileio/read_text.py
@@ -4,7 +4,6 @@
from typing import List
from typing import Union
-from typeguard import check_argument_types
def read_2column_text(path: Union[Path, str]) -> Dict[str, str]:
@@ -19,7 +18,6 @@
{'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
"""
- assert check_argument_types()
data = {}
with Path(path).open("r", encoding="utf-8") as f:
@@ -47,7 +45,6 @@
>>> d = load_num_sequence_text('text')
>>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
"""
- assert check_argument_types()
if loader_type == "text_int":
delimiter = " "
dtype = int
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index c752fe6..b912f1e 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -1,17 +1,84 @@
import collections.abc
from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
import random
import numpy as np
import soundfile
import librosa
-from typeguard import check_argument_types
import torch
import torchaudio
from funasr.fileio.read_text import read_2column_text
+
+def soundfile_read(
+ wavs: Union[str, List[str]],
+ dtype=None,
+ always_2d: bool = False,
+ concat_axis: int = 1,
+ start: int = 0,
+ end: int = None,
+ return_subtype: bool = False,
+) -> Tuple[np.array, int]:
+ if isinstance(wavs, str):
+ wavs = [wavs]
+
+ arrays = []
+ subtypes = []
+ prev_rate = None
+ prev_wav = None
+ for wav in wavs:
+ with soundfile.SoundFile(wav) as f:
+ f.seek(start)
+ if end is not None:
+ frames = end - start
+ else:
+ frames = -1
+ if dtype == "float16":
+ array = f.read(
+ frames,
+ dtype="float32",
+ always_2d=always_2d,
+ ).astype(dtype)
+ else:
+ array = f.read(frames, dtype=dtype, always_2d=always_2d)
+ rate = f.samplerate
+ subtype = f.subtype
+ subtypes.append(subtype)
+
+ if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+ # array: (Time, Channel)
+ array = array[:, None]
+
+ if prev_wav is not None:
+ if prev_rate != rate:
+ raise RuntimeError(
+ f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+ f"{prev_rate} != {rate}"
+ )
+
+ dim1 = arrays[0].shape[1 - concat_axis]
+ dim2 = array.shape[1 - concat_axis]
+ if dim1 != dim2:
+ raise RuntimeError(
+ "Shapes must match with "
+ f"{1 - concat_axis} axis, but gut {dim1} and {dim2}"
+ )
+
+ prev_rate = rate
+ prev_wav = wav
+ arrays.append(array)
+
+ if len(arrays) == 1:
+ array = arrays[0]
+ else:
+ array = np.concatenate(arrays, axis=concat_axis)
+
+ if return_subtype:
+ return array, rate, subtypes
+ else:
+ return array, rate
class SoundScpReader(collections.abc.Mapping):
@@ -38,7 +105,6 @@
dest_sample_rate: int = 16000,
speed_perturb: Union[list, tuple] = None,
):
- assert check_argument_types()
self.fname = fname
self.dtype = dtype
self.always_2d = always_2d
@@ -111,7 +177,6 @@
format="wav",
dtype=None,
):
- assert check_argument_types()
self.dir = Path(outdir)
self.dir.mkdir(parents=True, exist_ok=True)
scpfile = Path(scpfile)
diff --git a/funasr/iterators/chunk_iter_factory.py b/funasr/iterators/chunk_iter_factory.py
index cec6370..5a54632 100644
--- a/funasr/iterators/chunk_iter_factory.py
+++ b/funasr/iterators/chunk_iter_factory.py
@@ -9,7 +9,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
@@ -51,7 +50,6 @@
collate_fn=None,
pin_memory: bool = False,
):
- assert check_argument_types()
assert all(len(x) == 1 for x in batches), "batch-size must be 1"
self.per_sample_iter_factory = SequenceIterFactory(
diff --git a/funasr/iterators/multiple_iter_factory.py b/funasr/iterators/multiple_iter_factory.py
index 088016c..3587a2a 100644
--- a/funasr/iterators/multiple_iter_factory.py
+++ b/funasr/iterators/multiple_iter_factory.py
@@ -4,7 +4,6 @@
from typing import Iterator
import numpy as np
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
@@ -16,7 +15,6 @@
seed: int = 0,
shuffle: bool = False,
):
- assert check_argument_types()
self.build_funcs = list(build_funcs)
self.seed = seed
self.shuffle = shuffle
diff --git a/funasr/iterators/sequence_iter_factory.py b/funasr/iterators/sequence_iter_factory.py
index 39d0834..41de37c 100644
--- a/funasr/iterators/sequence_iter_factory.py
+++ b/funasr/iterators/sequence_iter_factory.py
@@ -4,7 +4,6 @@
import numpy as np
from torch.utils.data import DataLoader
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
@@ -46,7 +45,6 @@
collate_fn=None,
pin_memory: bool = False,
):
- assert check_argument_types()
if not isinstance(batches, AbsSampler):
self.sampler = RawSampler(batches)
diff --git a/funasr/layers/global_mvn.py b/funasr/layers/global_mvn.py
index 8e43582..b94e9ca 100644
--- a/funasr/layers/global_mvn.py
+++ b/funasr/layers/global_mvn.py
@@ -4,7 +4,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@@ -28,7 +27,6 @@
norm_vars: bool = True,
eps: float = 1.0e-20,
):
- assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
diff --git a/funasr/layers/label_aggregation.py b/funasr/layers/label_aggregation.py
index 29a08a9..8366a79 100644
--- a/funasr/layers/label_aggregation.py
+++ b/funasr/layers/label_aggregation.py
@@ -1,5 +1,4 @@
import torch
-from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
@@ -13,7 +12,6 @@
hop_length: int = 128,
center: bool = True,
):
- assert check_argument_types()
super().__init__()
self.win_length = win_length
diff --git a/funasr/layers/mask_along_axis.py b/funasr/layers/mask_along_axis.py
index e49e621..416c4ea 100644
--- a/funasr/layers/mask_along_axis.py
+++ b/funasr/layers/mask_along_axis.py
@@ -1,6 +1,5 @@
import math
import torch
-from typeguard import check_argument_types
from typing import Sequence
from typing import Union
@@ -147,7 +146,6 @@
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
- assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
@@ -214,7 +212,6 @@
dim: Union[int, str] = "time",
replace_with_zero: bool = True,
):
- assert check_argument_types()
if isinstance(mask_width_ratio_range, float):
mask_width_ratio_range = (0.0, mask_width_ratio_range)
if len(mask_width_ratio_range) != 2:
@@ -283,7 +280,6 @@
replace_with_zero: bool = True,
lfr_rate: int = 1,
):
- assert check_argument_types()
if isinstance(mask_width_range, int):
mask_width_range = (0, mask_width_range)
if len(mask_width_range) != 2:
diff --git a/funasr/layers/sinc_conv.py b/funasr/layers/sinc_conv.py
index 33df97f..ab16a73 100644
--- a/funasr/layers/sinc_conv.py
+++ b/funasr/layers/sinc_conv.py
@@ -5,7 +5,6 @@
"""Sinc convolutions."""
import math
import torch
-from typeguard import check_argument_types
from typing import Union
@@ -71,7 +70,6 @@
window_func: Window function on the filter, one of ["hamming", "none"].
fs (str, int, float): Sample rate of the input data
"""
- assert check_argument_types()
super().__init__()
window_funcs = {
"none": self.none_window,
@@ -208,7 +206,6 @@
torch.Tensor: Filter start frequenc铆es.
torch.Tensor: Filter stop frequencies.
"""
- assert check_argument_types()
# min and max bandpass edge frequencies
min_frequency = torch.tensor(30.0)
max_frequency = torch.tensor(fs * 0.5)
@@ -257,7 +254,6 @@
torch.Tensor: Filter start frequenc铆es.
torch.Tensor: Filter stop frequenc铆es.
"""
- assert check_argument_types()
# min and max BARK center frequencies by approximation
min_center_frequency = torch.tensor(70.0)
max_center_frequency = torch.tensor(fs * 0.45)
diff --git a/funasr/layers/stft.py b/funasr/layers/stft.py
index 376b5a3..dfb6919 100644
--- a/funasr/layers/stft.py
+++ b/funasr/layers/stft.py
@@ -5,7 +5,6 @@
import torch
from torch_complex.tensor import ComplexTensor
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.complex_utils import is_complex
@@ -30,7 +29,6 @@
normalized: bool = False,
onesided: bool = True,
):
- assert check_argument_types()
super().__init__()
self.n_fft = n_fft
if win_length is None:
diff --git a/funasr/layers/utterance_mvn.py b/funasr/layers/utterance_mvn.py
index 50f27cd..7722974 100644
--- a/funasr/layers/utterance_mvn.py
+++ b/funasr/layers/utterance_mvn.py
@@ -1,7 +1,6 @@
from typing import Tuple
import torch
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
@@ -14,7 +13,6 @@
norm_vars: bool = False,
eps: float = 1.0e-20,
):
- assert check_argument_types()
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
diff --git a/funasr/main_funcs/average_nbest_models.py b/funasr/main_funcs/average_nbest_models.py
index d8df949..96e1384 100644
--- a/funasr/main_funcs/average_nbest_models.py
+++ b/funasr/main_funcs/average_nbest_models.py
@@ -8,7 +8,6 @@
from io import BytesIO
import torch
-from typeguard import check_argument_types
from typing import Collection
from funasr.train.reporter import Reporter
@@ -34,7 +33,6 @@
nbest: Number of best model files to be averaged
suffix: A suffix added to the averaged model file name
"""
- assert check_argument_types()
if isinstance(nbest, int):
nbests = [nbest]
else:
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index 584b85a..ee2182c 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -11,7 +11,6 @@
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
-from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.fileio.npy_scp import NpyScpWriter
@@ -37,7 +36,6 @@
This method is used before executing train().
"""
- assert check_argument_types()
npy_scp_writers = {}
for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
diff --git a/funasr/models/ctc.py b/funasr/models/ctc.py
index 64b8710..d3c10fa 100644
--- a/funasr/models/ctc.py
+++ b/funasr/models/ctc.py
@@ -2,7 +2,6 @@
import torch
import torch.nn.functional as F
-from typeguard import check_argument_types
class CTC(torch.nn.Module):
@@ -25,7 +24,6 @@
reduce: bool = True,
ignore_nan_grad: bool = True,
):
- assert check_argument_types()
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
@@ -41,11 +39,6 @@
if ignore_nan_grad:
logging.warning("ignore_nan_grad option is not supported for warp_ctc")
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
-
- elif self.ctc_type == "gtnctc":
- from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
-
- self.ctc_loss = GTNCTCLossFunction.apply
else:
raise ValueError(
f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index e5bd640..92c95cc 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -10,7 +10,6 @@
from typing import Tuple
import torch
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -40,7 +39,6 @@
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
):
- assert check_argument_types()
super().__init__()
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 78105ab..0e69c44 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -7,7 +7,6 @@
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-from typeguard import check_argument_types
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@@ -126,7 +125,6 @@
kernel_size: int = 21,
sanm_shfit: int = 0,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
diff --git a/funasr/models/decoder/rnn_decoder.py b/funasr/models/decoder/rnn_decoder.py
index 80709c9..cb119e1 100644
--- a/funasr/models/decoder/rnn_decoder.py
+++ b/funasr/models/decoder/rnn_decoder.py
@@ -3,7 +3,6 @@
import numpy as np
import torch
import torch.nn.functional as F
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.nets_utils import to_device
@@ -97,7 +96,6 @@
att_conf: dict = get_default_kwargs(build_attention_list),
):
# FIXME(kamo): The parts of num_spk should be refactored more more more
- assert check_argument_types()
if rnn_type not in {"lstm", "gru"}:
raise ValueError(f"Not supported: rnn_type={rnn_type}")
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
index a0fe9ea..0109cc5 100644
--- a/funasr/models/decoder/rnnt_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -3,7 +3,6 @@
from typing import List, Optional, Tuple
import torch
-from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.specaug.specaug import SpecAug
@@ -38,7 +37,6 @@
"""Construct a RNNDecoder object."""
super().__init__()
- assert check_argument_types()
if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index ed920bf..d83f89f 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -7,7 +7,6 @@
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-from typeguard import check_argument_types
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@@ -181,7 +180,6 @@
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
embed_tensor_name_prefix_tf: str = None,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -838,7 +836,6 @@
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index 45fdda8..0a9c612 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -9,7 +9,6 @@
import torch
from torch import nn
-from typeguard import check_argument_types
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
@@ -184,7 +183,6 @@
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
- assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
@@ -373,7 +371,6 @@
normalize_before: bool = True,
concat_after: bool = False,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -428,7 +425,6 @@
concat_after: bool = False,
embeds_id: int = -1,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -540,7 +536,6 @@
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
- assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@@ -602,7 +597,6 @@
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
- assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@@ -664,7 +658,6 @@
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
- assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@@ -726,7 +719,6 @@
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: int = False,
):
- assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
"conv_kernel_length must have equal number of values to num_blocks: "
@@ -781,7 +773,6 @@
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
- assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
@@ -955,7 +946,6 @@
normalize_before: bool = True,
concat_after: bool = False,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index e6e6a52..79c5387 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -11,7 +11,6 @@
from typing import Union
import torch
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -65,7 +64,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
diff --git a/funasr/models/e2e_asr_contextual_paraformer.py b/funasr/models/e2e_asr_contextual_paraformer.py
index dc820db..4836663 100644
--- a/funasr/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/models/e2e_asr_contextual_paraformer.py
@@ -9,7 +9,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.ctc import CTC
@@ -43,9 +42,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -72,8 +69,9 @@
crit_attn_weight: float = 0.0,
crit_attn_smooth: float = 0.0,
bias_encoder_dropout_rate: float = 0.0,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index fbf0d11..7dd3b8d 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -7,7 +7,6 @@
from typing import Union
import logging
import torch
-from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@@ -53,7 +52,7 @@
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
- rnnt_decoder: None,
+ rnnt_decoder: None = None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
@@ -65,7 +64,6 @@
sym_blank: str = "<blank>",
preencoder: Optional[AbsPreEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert rnnt_decoder is None, "Not implemented"
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 686038e..5a1a29b 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -10,7 +10,6 @@
import torch
import random
import numpy as np
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -80,7 +79,6 @@
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -645,7 +643,6 @@
postencoder: Optional[AbsPostEncoder] = None,
use_1st_decoder_loss: bool = False,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1255,7 +1252,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1528,7 +1524,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -1806,7 +1801,6 @@
preencoder: Optional[AbsPreEncoder] = None,
postencoder: Optional[AbsPostEncoder] = None,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index 3f9f31c..80914b1 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -6,8 +6,9 @@
import torch
from packaging.version import parse as V
-from typeguard import check_argument_types
-
+from funasr.losses.label_smoothing_loss import (
+ LabelSmoothingLoss, # noqa: H301
+)
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
@@ -15,6 +16,8 @@
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
+from funasr.modules.nets_utils import th_accuracy
+from funasr.modules.add_sos_eos import add_sos_eos
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -81,8 +84,6 @@
) -> None:
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()
-
- assert check_argument_types()
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
@@ -542,8 +543,6 @@
"""Construct an ESPnetASRTransducerModel object."""
super().__init__()
- assert check_argument_types()
-
# The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
self.blank_id = 0
@@ -709,7 +708,7 @@
loss_lm = self._calc_lm_loss(decoder_out, target)
loss_trans = loss_trans_utt + loss_trans_chunk
- loss_ctc = loss_ctc + loss_ctc_chunk
+ loss_ctc = loss_ctc + loss_ctc_chunk
loss_ctc = loss_att + loss_att_chunk
loss = (
@@ -1014,4 +1013,4 @@
ignore_label=self.ignore_id,
)
- return loss_att, acc_att
+ return loss_att, acc_att
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index da7c674..ae3a436 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -9,7 +9,6 @@
import numpy as np
import torch
import torch.nn as nn
-from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
@@ -48,7 +47,6 @@
mapping_dict=None,
**kwargs,
):
- assert check_argument_types()
super().__init__()
self.frontend = frontend
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 9c3fb92..bc93b9d 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -12,7 +12,6 @@
import numpy as np
import torch
from torch.nn import functional as F
-from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -66,7 +65,6 @@
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
):
- assert check_argument_types()
super().__init__()
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 8304607..cf1587d 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -12,7 +12,6 @@
import torch
import torch.nn.functional as F
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -40,7 +39,7 @@
yield
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -51,10 +50,8 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
@@ -69,7 +66,6 @@
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -89,8 +85,6 @@
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
- self.preencoder = preencoder
- self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
@@ -293,10 +287,6 @@
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +307,6 @@
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
else:
encoder_out_spk=encoder_out_spk_ori
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -337,7 +322,7 @@
)
if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
+ return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index bd5178e..8be63d4 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -12,7 +12,6 @@
from typing import Union
import torch
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -56,7 +55,6 @@
pooling_layer: torch.nn.Module,
decoder: AbsDecoder,
):
- assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 33948f9..567dc70 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -9,7 +9,6 @@
import torch
import numpy as np
-from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
@@ -42,7 +41,6 @@
predictor_bias: int = 0,
token_list=None,
):
- assert check_argument_types()
super().__init__()
# note that eos is the same as sos (equivalent ID)
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index d08ea37..8bc3b42 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -8,7 +8,6 @@
from typing import Union
import torch
-from typeguard import check_argument_types
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@@ -50,9 +49,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -80,9 +77,10 @@
loss_weight_model1: float = 0.5,
enable_maas_finetune: bool = False,
freeze_encoder2: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
encoder1_encoder2_joint_training: bool = True,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 71ed2cf..7c55b2e 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
return int(self.frame_size_ms)
-class E2EVadModel(nn.Module):
+class E2EVadModel(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -252,8 +253,8 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.ResetDetection()
self.frontend = frontend
+ self.last_drop_frames = 0
def AllResetDetection(self):
self.data_buf_start_frame = 0
@@ -282,7 +283,8 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.ResetDetection()
+ self.last_drop_frames = 0
+ self.windows_detector.Reset()
def ResetDetection(self):
self.continous_silence_frame_count = 0
@@ -294,6 +296,15 @@
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
+
+ if self.output_data_buf:
+ assert self.output_data_buf[-1].contain_seg_end_point == True
+ drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+ real_drop_frames = drop_frames - self.last_drop_frames
+ self.last_drop_frames = drop_frames
+ self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+ self.decibel = self.decibel[real_drop_frames:]
+ self.scores = self.scores[:, real_drop_frames:, :]
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
@@ -322,7 +333,7 @@
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
- self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+ self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
@@ -543,7 +554,7 @@
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
- frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
return 0
@@ -553,7 +564,7 @@
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
- frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 5f20dee..e5fac62 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -12,7 +12,6 @@
import torch
from torch import nn
-from typeguard import check_argument_types
from funasr.models.ctc import CTC
from funasr.modules.attention import (
@@ -533,7 +532,6 @@
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -943,7 +941,6 @@
"""Construct an Encoder object."""
super().__init__()
- assert check_argument_types()
self.embed = StreamingConvInput(
input_size,
@@ -1081,7 +1078,10 @@
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ if self.training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ else:
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
chunk_mask = make_chunk_mask(
@@ -1113,12 +1113,15 @@
elif self.dynamic_chunk_training:
max_len = x.size(1)
- chunk_size = torch.randint(1, max_len, (1,)).item()
+ if self.training:
+ chunk_size = torch.randint(1, max_len, (1,)).item()
- if chunk_size > (max_len * self.short_chunk_threshold):
- chunk_size = max_len
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
else:
- chunk_size = (chunk_size % self.short_chunk_size) + 1
+ chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
pos_enc = self.pos_enc(x)
@@ -1147,6 +1150,45 @@
return x, olens, None
+ def full_utt_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ x_len: Encoder outputs lenghts. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ return x_utt
+
def simu_chunk_forward(
self,
x: torch.Tensor,
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index 64c2144..8885f02 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -10,7 +10,6 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
-from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
@@ -97,7 +96,6 @@
# FP16 optimization
required_seq_len_multiple: int = 2,
):
- assert check_argument_types()
super().__init__()
# ConvFeatureExtractionModel
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 95ccf07..87bb19d 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -5,7 +5,6 @@
import torch
from torch import nn
-from typeguard import check_argument_types
from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
@@ -161,7 +160,6 @@
cnn_module_kernel: int = 31,
padding_idx: int = -1,
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
index eec854f..9ab5e6b 100644
--- a/funasr/models/encoder/opennmt_encoders/conv_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
-from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
@@ -90,7 +89,6 @@
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = num_units
diff --git a/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
index e41b2aa..5f62e67 100644
--- a/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
-from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.layer_norm import LayerNorm
diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
index db30f08..7c83cbd 100644
--- a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
@@ -144,7 +143,6 @@
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 59730da..434af09 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -5,7 +5,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
@@ -37,7 +36,6 @@
dropout: float = 0.0,
subsample: Optional[Sequence[int]] = (2, 2, 1, 1),
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
self.rnn_type = rnn_type
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 46eabd1..45163df 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -8,7 +8,6 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -151,7 +150,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -601,7 +599,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -1060,7 +1057,6 @@
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
diff --git a/funasr/models/encoder/transformer_encoder.py b/funasr/models/encoder/transformer_encoder.py
index ff9c3db..4f2bef5 100644
--- a/funasr/models/encoder/transformer_encoder.py
+++ b/funasr/models/encoder/transformer_encoder.py
@@ -9,7 +9,6 @@
import torch
from torch import nn
-from typeguard import check_argument_types
import logging
from funasr.models.ctc import CTC
@@ -189,7 +188,6 @@
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 19994f0..b41af80 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -2,18 +2,18 @@
from typing import Optional
from typing import Tuple
from typing import Union
-
+import logging
import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
-from typeguard import check_argument_types
from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
class DefaultFrontend(AbsFrontend):
@@ -39,7 +39,6 @@
apply_stft: bool = True,
use_channel: int = None,
):
- assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@@ -76,8 +75,8 @@
htk=htk,
)
self.n_mels = n_mels
- self.frontend_type = "default"
self.use_channel = use_channel
+ self.frontend_type = "default"
def output_size(self) -> int:
return self.n_mels
@@ -137,8 +136,6 @@
return input_stft, feats_lens
-
-
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -149,7 +146,9 @@
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
- hop_length: int = 128,
+ hop_length: int = None,
+ frame_length: int = None,
+ frame_shift: int = None,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
@@ -160,25 +159,36 @@
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
- frame_length: int = None,
- frame_shift: int = None,
- lfr_m: int = None,
- lfr_n: int = None,
+ use_channel: int = None,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ cmvn_file: str = None,
+ mc: bool = True
):
- assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
- self.hop_length = hop_length
+ if win_length is None and hop_length is None:
+ self.win_length = frame_length * 16
+ self.hop_length = frame_shift * 16
+ elif frame_length is None and frame_shift is None:
+ self.win_length = self.win_length
+ self.hop_length = self.hop_length
+ else:
+ logging.error(
+ "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
+ "can be set."
+ )
+ exit(1)
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
center=center,
window=window,
normalized=normalized,
@@ -202,6 +212,18 @@
htk=htk,
)
self.n_mels = n_mels
+ self.use_channel = use_channel
+ self.mc = mc
+ if not self.mc:
+ if self.use_channel is not None:
+ logging.info("use the channel %d" % (self.use_channel))
+ else:
+ logging.info("random select channel")
+ self.cmvn_file = cmvn_file
+ if self.cmvn_file is not None:
+ mean, std = self._load_cmvn(self.cmvn_file)
+ self.register_buffer("mean", torch.from_numpy(mean))
+ self.register_buffer("std", torch.from_numpy(std))
self.frontend_type = "multichannelfrontend"
def output_size(self) -> int:
@@ -215,16 +237,29 @@
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
- if isinstance(input, ComplexTensor):
- input_stft = input
- else:
- input_stft = ComplexTensor(input[..., 0], input[..., 1])
+ input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
+
+ # 3. [Multi channel case]: Select a channel(sa_asr)
+ if input_stft.dim() == 4 and not self.mc:
+ # h: (B, T, C, F) -> h: (B, T, F)
+ if self.training:
+ if self.use_channel is not None:
+ input_stft = input_stft[:, :, self.use_channel, :]
+
+ else:
+ # Select 1ch randomly
+ ch = np.random.randint(input_stft.size(2))
+ input_stft = input_stft[:, :, ch, :]
+ else:
+ # Use the first channel
+ input_stft = input_stft[:, :, 0, :]
+
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +268,37 @@
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
- bt = input_feats.size(0)
- if input_feats.dim() ==4:
- channel_size = input_feats.size(2)
- # batch * channel * T * D
- #pdb.set_trace()
- input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
- # input_feats = input_feats.transpose(1,2)
- # batch * channel
- feats_lens = feats_lens.repeat(1,channel_size).squeeze()
+ if self.mc:
+ # MFCCA
+ if input_feats.dim() ==4:
+ bt = input_feats.size(0)
+ channel_size = input_feats.size(2)
+ input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
+ feats_lens = feats_lens.repeat(1,channel_size).squeeze()
+ else:
+ channel_size = 1
+ return input_feats, feats_lens, channel_size
else:
- channel_size = 1
- return input_feats, feats_lens, channel_size
+ # 6. Apply CMVN
+ if self.cmvn_file is not None:
+ if feats_lens is None:
+ feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+ self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+ self.std = self.std.to(input_feats.device, input_feats.dtype)
+ mask = make_pad_mask(feats_lens, input_feats, 1)
+
+ if input_feats.requires_grad:
+ input_feats = input_feats + self.mean
+ else:
+ input_feats += self.mean
+ if input_feats.requires_grad:
+ input_feats = input_feats.masked_fill(mask, 0.0)
+ else:
+ input_feats.masked_fill_(mask, 0.0)
+
+ input_feats *= self.std
+
+ return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +312,27 @@
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
\ No newline at end of file
+ return input_stft, feats_lens
+
+ def _load_cmvn(self, cmvn_file):
+ with open(cmvn_file, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ means_list = []
+ vars_list = []
+ for i in range(len(lines)):
+ line_item = lines[i].split()
+ if line_item[0] == '<AddShift>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ add_shift_line = line_item[3:(len(line_item) - 1)]
+ means_list = list(add_shift_line)
+ continue
+ elif line_item[0] == '<Rescale>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ rescale_line = line_item[3:(len(line_item) - 1)]
+ vars_list = list(rescale_line)
+ continue
+ means = np.array(means_list).astype(np.float)
+ vars = np.array(vars_list).astype(np.float)
+ return means, vars
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 857486d..ff95871 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -3,7 +3,6 @@
from funasr.models.frontend.s3prl import S3prlFrontend
import numpy as np
import torch
-from typeguard import check_argument_types
from typing import Tuple
@@ -12,7 +11,6 @@
self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
):
- assert check_argument_types()
super().__init__()
self.align_method = (
align_method # fusing method : linear_projection only for now
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index b03d2c9..fdeb1c5 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -8,7 +8,6 @@
import humanfriendly
import torch
-from typeguard import check_argument_types
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
@@ -37,7 +36,6 @@
download_dir: str = None,
multilayer_feature: bool = False,
):
- assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index f16bdd9..acab13b 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -6,7 +6,6 @@
import torch
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
-from typeguard import check_argument_types
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.models.frontend.abs_frontend import AbsFrontend
@@ -95,7 +94,6 @@
snip_edges: bool = True,
upsacle_samples: bool = True,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@@ -227,7 +225,6 @@
snip_edges: bool = True,
upsacle_samples: bool = True,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
@@ -466,7 +463,6 @@
lfr_m: int = 1,
lfr_n: int = 1,
):
- assert check_argument_types()
super().__init__()
self.fs = fs
self.frame_length = frame_length
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index a526758..94c9d27 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -6,7 +6,6 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
import torch
-from typeguard import check_argument_types
from typing import Tuple
@@ -38,7 +37,6 @@
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
- assert check_argument_types()
super().__init__()
self.fs = fs
self.win_length = win_length
diff --git a/funasr/models/postencoder/hugging_face_transformers_postencoder.py b/funasr/models/postencoder/hugging_face_transformers_postencoder.py
index 1aad15d..c59e2b7 100644
--- a/funasr/models/postencoder/hugging_face_transformers_postencoder.py
+++ b/funasr/models/postencoder/hugging_face_transformers_postencoder.py
@@ -6,7 +6,6 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from typeguard import check_argument_types
from typing import Tuple
import copy
@@ -30,7 +29,6 @@
model_name_or_path: str,
):
"""Initialize the module."""
- assert check_argument_types()
super().__init__()
if not is_transformers_available:
diff --git a/funasr/models/preencoder/linear.py b/funasr/models/preencoder/linear.py
index c69b6ce..25f6720 100644
--- a/funasr/models/preencoder/linear.py
+++ b/funasr/models/preencoder/linear.py
@@ -5,7 +5,6 @@
"""Linear Projection."""
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from typeguard import check_argument_types
from typing import Tuple
import torch
@@ -20,7 +19,6 @@
output_size: int,
):
"""Initialize the module."""
- assert check_argument_types()
super().__init__()
self.output_dim = output_size
diff --git a/funasr/models/preencoder/sinc.py b/funasr/models/preencoder/sinc.py
index fe6d2af..3baa4a8 100644
--- a/funasr/models/preencoder/sinc.py
+++ b/funasr/models/preencoder/sinc.py
@@ -10,7 +10,6 @@
from funasr.layers.sinc_conv import SincConv
import humanfriendly
import torch
-from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from typing import Union
@@ -60,7 +59,6 @@
windowing_type: Choice of windowing function.
scale_type: Choice of filter-bank initialization scale.
"""
- assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
@@ -268,7 +266,6 @@
dropout_probability: Dropout probability.
shape (tuple, list): Shape of input tensors.
"""
- assert check_argument_types()
super().__init__()
if shape is None:
shape = (0, 2, 1)
diff --git a/funasr/models/seq_rnn_lm.py b/funasr/models/seq_rnn_lm.py
index f7ddcae..bef4974 100644
--- a/funasr/models/seq_rnn_lm.py
+++ b/funasr/models/seq_rnn_lm.py
@@ -4,7 +4,6 @@
import torch
import torch.nn as nn
-from typeguard import check_argument_types
from funasr.train.abs_model import AbsLM
@@ -27,7 +26,6 @@
rnn_type: str = "lstm",
ignore_id: int = 0,
):
- assert check_argument_types()
super().__init__()
ninp = unit
diff --git a/funasr/modules/eend_ola/utils/__init__.py b/funasr/modules/eend_ola/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/__init__.py
diff --git a/funasr/modules/eend_ola/utils/report.py b/funasr/modules/eend_ola/utils/report.py
index bfccedf..f4a044b 100644
--- a/funasr/modules/eend_ola/utils/report.py
+++ b/funasr/modules/eend_ola/utils/report.py
@@ -2,7 +2,7 @@
import numpy as np
import time
import torch
-from eend.utils.power import create_powerlabel
+from funasr.modules.eend_ola.utils.power import create_powerlabel
from itertools import combinations
metrics = [
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index a2b91a7..77aa422 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -427,6 +427,7 @@
conv_size: Union[int, Tuple],
subsampling_factor: int = 4,
vgg_like: bool = True,
+ conv_kernel_size: int = 3,
output_size: Optional[int] = None,
) -> None:
"""Construct a ConvInput object."""
@@ -436,14 +437,14 @@
conv_size1, conv_size2 = conv_size
self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((1, 2)),
)
@@ -462,14 +463,14 @@
kernel_1 = int(subsampling_factor / 2)
self.conv = torch.nn.Sequential(
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((kernel_1, 2)),
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2, 2)),
)
@@ -487,14 +488,14 @@
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
torch.nn.ReLU(),
- torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+ torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
torch.nn.ReLU(),
)
output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
self.subsampling_factor = subsampling_factor
- self.kernel_2 = 3
+ self.kernel_2 = conv_kernel_size
self.stride_2 = 1
self.create_new_mask = self.create_new_conv2d_mask
diff --git a/funasr/optimizers/sgd.py b/funasr/optimizers/sgd.py
index 3f0d3d1..fb7a3df 100644
--- a/funasr/optimizers/sgd.py
+++ b/funasr/optimizers/sgd.py
@@ -1,5 +1,4 @@
import torch
-from typeguard import check_argument_types
class SGD(torch.optim.SGD):
@@ -21,7 +20,6 @@
weight_decay: float = 0.0,
nesterov: bool = False,
):
- assert check_argument_types()
super().__init__(
params,
lr=lr,
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj
new file mode 100644
index 0000000..b494bb5
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj
@@ -0,0 +1,18 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <OutputType>Exe</OutputType>
+ <TargetFramework>net6.0</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <PackageReference Include="NAudio" Version="2.1.0" />
+ </ItemGroup>
+
+ <ItemGroup>
+ <ProjectReference Include="..\AliFsmnVadSharp\AliFsmnVadSharp.csproj" />
+ </ItemGroup>
+
+</Project>
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs
new file mode 100644
index 0000000..dd3bf78
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp.Examples/Program.cs
@@ -0,0 +1,61 @@
+锘縰sing AliFsmnVadSharp;
+using AliFsmnVadSharp.Model;
+using NAudio.Wave;
+
+internal static class Program
+{
+ [STAThread]
+ private static void Main()
+ {
+ string applicationBase = AppDomain.CurrentDomain.BaseDirectory;
+ string modelFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx";
+ string configFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.yaml";
+ string mvnFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.mvn";
+ int batchSize = 2;
+ TimeSpan start_time0 = new TimeSpan(DateTime.Now.Ticks);
+ AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize);
+ TimeSpan end_time0 = new TimeSpan(DateTime.Now.Ticks);
+ double elapsed_milliseconds0 = end_time0.TotalMilliseconds - start_time0.TotalMilliseconds;
+ Console.WriteLine("load model and init config elapsed_milliseconds:{0}", elapsed_milliseconds0.ToString());
+ List<float[]> samples = new List<float[]>();
+ TimeSpan total_duration = new TimeSpan(0L);
+ for (int i = 0; i < 2; i++)
+ {
+ string wavFilePath = string.Format(applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/example/{0}.wav", i.ToString());//vad_example
+ if (!File.Exists(wavFilePath))
+ {
+ continue;
+ }
+ AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath);
+ byte[] datas = new byte[_audioFileReader.Length];
+ _audioFileReader.Read(datas, 0, datas.Length);
+ TimeSpan duration = _audioFileReader.TotalTime;
+ float[] wavdata = new float[datas.Length / 4];
+ Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length);
+ float[] sample = wavdata.Select((float x) => x * 32768f).ToArray();
+ samples.Add(wavdata);
+ total_duration += duration;
+ }
+ TimeSpan start_time = new TimeSpan(DateTime.Now.Ticks);
+ //SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples);
+ SegmentEntity[] segments_duration = aliFsmnVad.GetSegmentsByStep(samples);
+ TimeSpan end_time = new TimeSpan(DateTime.Now.Ticks);
+ Console.WriteLine("vad infer result:");
+ foreach (SegmentEntity segment in segments_duration)
+ {
+ Console.Write("[");
+ foreach (var x in segment.Segment)
+ {
+ Console.Write("[" + string.Join(",", x.ToArray()) + "]");
+ }
+ Console.Write("]\r\n");
+ }
+
+ double elapsed_milliseconds = end_time.TotalMilliseconds - start_time.TotalMilliseconds;
+ double rtf = elapsed_milliseconds / total_duration.TotalMilliseconds;
+ Console.WriteLine("elapsed_milliseconds:{0}", elapsed_milliseconds.ToString());
+ Console.WriteLine("total_duration:{0}", total_duration.TotalMilliseconds.ToString());
+ Console.WriteLine("rtf:{1}", "0".ToString(), rtf.ToString());
+ Console.WriteLine("------------------------");
+ }
+}
\ No newline at end of file
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp.sln b/funasr/runtime/csharp/AliFsmnVadSharp.sln
new file mode 100644
index 0000000..8bf24aa
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp.sln
@@ -0,0 +1,37 @@
+锘�
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.1.32210.238
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AliFsmnVadSharp", "AliFsmnVadSharp\AliFsmnVadSharp.csproj", "{BFB82F2E-AD5B-405C-AAFF-3CE33C548748}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AliFsmnVadSharp.Examples", "AliFsmnVadSharp.Examples\AliFsmnVadSharp.Examples.csproj", "{2FFA4D03-A62B-435B-B57B-7E49209810E1}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{212561CC-9836-4F45-A31B-298EF576F519}"
+ ProjectSection(SolutionItems) = preProject
+ license = license
+ README.md = README.md
+ EndProjectSection
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {BFB82F2E-AD5B-405C-AAFF-3CE33C548748}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {2FFA4D03-A62B-435B-B57B-7E49209810E1}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {FCC1BBCC-91A3-4223-B368-D272FB5108B6}
+ EndGlobalSection
+EndGlobal
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs
new file mode 100644
index 0000000..f42bfb1
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVad.cs
@@ -0,0 +1,387 @@
+锘縰sing System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.ML;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using Microsoft.Extensions.Logging;
+using AliFsmnVadSharp.Model;
+using AliFsmnVadSharp.Utils;
+
+namespace AliFsmnVadSharp
+{
+ public class AliFsmnVad
+ {
+ private InferenceSession _onnxSession;
+ private readonly ILogger<AliFsmnVad> _logger;
+ private string _frontend;
+ private WavFrontend _wavFrontend;
+ private int _batchSize = 1;
+ private int _max_end_sil = int.MinValue;
+ private EncoderConfEntity _encoderConfEntity;
+ private VadPostConfEntity _vad_post_conf;
+
+ public AliFsmnVad(string modelFilePath, string configFilePath, string mvnFilePath, int batchSize = 1)
+ {
+ Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
+ options.AppendExecutionProvider_CPU(0);
+ options.InterOpNumThreads = 1;
+ _onnxSession = new InferenceSession(modelFilePath, options);
+
+ VadYamlEntity vadYamlEntity = YamlHelper.ReadYaml<VadYamlEntity>(configFilePath);
+ _wavFrontend = new WavFrontend(mvnFilePath, vadYamlEntity.frontend_conf);
+ _frontend = vadYamlEntity.frontend;
+ _vad_post_conf = vadYamlEntity.vad_post_conf;
+ _batchSize = batchSize;
+ _max_end_sil = _max_end_sil != int.MinValue ? _max_end_sil : vadYamlEntity.vad_post_conf.max_end_silence_time;
+ _encoderConfEntity = vadYamlEntity.encoder_conf;
+
+ ILoggerFactory loggerFactory = new LoggerFactory();
+ _logger = new Logger<AliFsmnVad>(loggerFactory);
+ }
+
+ public SegmentEntity[] GetSegments(List<float[]> samples)
+ {
+ int waveform_nums = samples.Count;
+ _batchSize = Math.Min(waveform_nums, _batchSize);
+ SegmentEntity[] segments = new SegmentEntity[waveform_nums];
+ for (int beg_idx = 0; beg_idx < waveform_nums; beg_idx += _batchSize)
+ {
+ int end_idx = Math.Min(waveform_nums, beg_idx + _batchSize);
+ List<float[]> waveform_list = new List<float[]>();
+ for (int i = beg_idx; i < end_idx; i++)
+ {
+ waveform_list.Add(samples[i]);
+ }
+ List<VadInputEntity> vadInputEntitys = ExtractFeats(waveform_list);
+ try
+ {
+ int t_offset = 0;
+ int step = Math.Min(waveform_list.Max(x => x.Length), 6000);
+ bool is_final = true;
+ List<VadOutputEntity> vadOutputEntitys = Infer(vadInputEntitys);
+ for (int batch_num = beg_idx; batch_num < end_idx; batch_num++)
+ {
+ var scores = vadOutputEntitys[batch_num - beg_idx].Scores;
+ SegmentEntity[] segments_part = vadInputEntitys[batch_num].VadScorer.DefaultCall(scores, waveform_list[batch_num - beg_idx], is_final: is_final, max_end_sil: _max_end_sil, online: false);
+ if (segments_part.Length > 0)
+ {
+#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ if (segments[batch_num] == null)
+ {
+ segments[batch_num] = new SegmentEntity();
+ }
+ segments[batch_num].Segment.AddRange(segments_part[0].Segment); //
+#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+
+ }
+ }
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ _logger.LogWarning("input wav is silence or noise");
+ segments = null;
+ }
+// for (int batch_num = 0; batch_num < _batchSize; batch_num++)
+// {
+// List<float[]> segment_waveforms = new List<float[]>();
+// foreach (int[] segment in segments[beg_idx + batch_num].Segment)
+// {
+// // (int)(16000 * (segment[0] / 1000.0) * 2);
+// int frame_length = (((6000 * 400) / 400 - 1) * 160 + 400) / 60 / 1000;
+// int frame_start = segment[0] * frame_length;
+// int frame_end = segment[1] * frame_length;
+// float[] segment_waveform = new float[frame_end - frame_start];
+// Array.Copy(waveform_list[batch_num], frame_start, segment_waveform, 0, segment_waveform.Length);
+// segment_waveforms.Add(segment_waveform);
+// }
+// segments[beg_idx + batch_num].Waveform.AddRange(segment_waveforms);
+// }
+ }
+
+ return segments;
+ }
+
+ public SegmentEntity[] GetSegmentsByStep(List<float[]> samples)
+ {
+ int waveform_nums = samples.Count;
+ _batchSize=Math.Min(waveform_nums, _batchSize);
+ SegmentEntity[] segments = new SegmentEntity[waveform_nums];
+ for (int beg_idx = 0; beg_idx < waveform_nums; beg_idx += _batchSize)
+ {
+ int end_idx = Math.Min(waveform_nums, beg_idx + _batchSize);
+ List<float[]> waveform_list = new List<float[]>();
+ for (int i = beg_idx; i < end_idx; i++)
+ {
+ waveform_list.Add(samples[i]);
+ }
+ List<VadInputEntity> vadInputEntitys = ExtractFeats(waveform_list);
+ int feats_len = vadInputEntitys.Max(x => x.SpeechLength);
+ List<float[]> in_cache = new List<float[]>();
+ in_cache = PrepareCache(in_cache);
+ try
+ {
+ int step = Math.Min(vadInputEntitys.Max(x => x.SpeechLength), 6000 * 400);
+ bool is_final = true;
+ for (int t_offset = 0; t_offset < (int)(feats_len); t_offset += Math.Min(step, feats_len - t_offset))
+ {
+
+ if (t_offset + step >= feats_len - 1)
+ {
+ step = feats_len - t_offset;
+ is_final = true;
+ }
+ else
+ {
+ is_final = false;
+ }
+ List<VadInputEntity> vadInputEntitys_step = new List<VadInputEntity>();
+ foreach (VadInputEntity vadInputEntity in vadInputEntitys)
+ {
+ VadInputEntity vadInputEntity_step = new VadInputEntity();
+ float[]? feats = vadInputEntity.Speech;
+ int curr_step = Math.Min(feats.Length - t_offset, step);
+ if (curr_step <= 0)
+ {
+ vadInputEntity_step.Speech = new float[32000];
+ vadInputEntity_step.SpeechLength = 0;
+ vadInputEntity_step.InCaches = in_cache;
+ vadInputEntity_step.Waveform = new float[(((int)(32000) / 400 - 1) * 160 + 400)];
+ vadInputEntitys_step.Add(vadInputEntity_step);
+ continue;
+ }
+ float[]? feats_step = new float[curr_step];
+ Array.Copy(feats, t_offset, feats_step, 0, feats_step.Length);
+ float[]? waveform = vadInputEntity.Waveform;
+ float[]? waveform_step = new float[Math.Min(waveform.Length, ((int)(t_offset + step) / 400 - 1) * 160 + 400) - t_offset / 400 * 160];
+ Array.Copy(waveform, t_offset / 400 * 160, waveform_step, 0, waveform_step.Length);
+ vadInputEntity_step.Speech = feats_step;
+ vadInputEntity_step.SpeechLength = feats_step.Length;
+ vadInputEntity_step.InCaches = vadInputEntity.InCaches;
+ vadInputEntity_step.Waveform = waveform_step;
+ vadInputEntitys_step.Add(vadInputEntity_step);
+ }
+ List<VadOutputEntity> vadOutputEntitys = Infer(vadInputEntitys_step);
+ for (int batch_num = 0; batch_num < _batchSize; batch_num++)
+ {
+ vadInputEntitys[batch_num].InCaches = vadOutputEntitys[batch_num].OutCaches;
+ var scores = vadOutputEntitys[batch_num].Scores;
+ SegmentEntity[] segments_part = vadInputEntitys[batch_num].VadScorer.DefaultCall(scores, vadInputEntitys_step[batch_num].Waveform, is_final: is_final, max_end_sil: _max_end_sil, online: false);
+ if (segments_part.Length > 0)
+ {
+
+#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ if (segments[beg_idx + batch_num] == null)
+ {
+ segments[beg_idx + batch_num] = new SegmentEntity();
+ }
+ if (segments_part[0] != null)
+ {
+ segments[beg_idx + batch_num].Segment.AddRange(segments_part[0].Segment);
+ }
+#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+
+ }
+ }
+ }
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ _logger.LogWarning("input wav is silence or noise");
+ segments = null;
+ }
+// for (int batch_num = 0; batch_num < _batchSize; batch_num++)
+// {
+// List<float[]> segment_waveforms=new List<float[]>();
+// foreach (int[] segment in segments[beg_idx + batch_num].Segment)
+// {
+// // (int)(16000 * (segment[0] / 1000.0) * 2);
+// int frame_length = (((6000 * 400) / 400 - 1) * 160 + 400) / 60 / 1000;
+// int frame_start = segment[0] * frame_length;
+// int frame_end = segment[1] * frame_length;
+// if(frame_end > waveform_list[batch_num].Length)
+// {
+// break;
+// }
+// float[] segment_waveform = new float[frame_end - frame_start];
+// Array.Copy(waveform_list[batch_num], frame_start, segment_waveform, 0, segment_waveform.Length);
+// segment_waveforms.Add(segment_waveform);
+// }
+// segments[beg_idx + batch_num].Waveform.AddRange(segment_waveforms);
+// }
+
+ }
+ return segments;
+ }
+
+ private List<float[]> PrepareCache(List<float[]> in_cache)
+ {
+ if (in_cache.Count > 0)
+ {
+ return in_cache;
+ }
+
+ int fsmn_layers = _encoderConfEntity.fsmn_layers;
+
+ int proj_dim = _encoderConfEntity.proj_dim;
+ int lorder = _encoderConfEntity.lorder;
+
+ for (int i = 0; i < fsmn_layers; i++)
+ {
+ float[] cache = new float[1 * proj_dim * (lorder - 1) * 1];
+ in_cache.Add(cache);
+ }
+ return in_cache;
+ }
+
+ private List<VadInputEntity> ExtractFeats(List<float[]> waveform_list)
+ {
+ List<float[]> in_cache = new List<float[]>();
+ in_cache = PrepareCache(in_cache);
+ List<VadInputEntity> vadInputEntitys = new List<VadInputEntity>();
+ foreach (var waveform in waveform_list)
+ {
+ float[] fbanks = _wavFrontend.GetFbank(waveform);
+ float[] features = _wavFrontend.LfrCmvn(fbanks);
+ VadInputEntity vadInputEntity = new VadInputEntity();
+ vadInputEntity.Waveform = waveform;
+ vadInputEntity.Speech = features;
+ vadInputEntity.SpeechLength = features.Length;
+ vadInputEntity.InCaches = in_cache;
+ vadInputEntity.VadScorer = new E2EVadModel(_vad_post_conf);
+ vadInputEntitys.Add(vadInputEntity);
+ }
+ return vadInputEntitys;
+ }
+ /// <summary>
+ /// 涓�缁存暟缁勮浆3缁存暟缁�
+ /// </summary>
+ /// <param name="obj"></param>
+ /// <param name="len">涓�缁撮暱</param>
+ /// <param name="wid">浜岀淮闀�</param>
+ /// <returns></returns>
+ public static T[,,] DimOneToThree<T>(T[] oneDimObj, int len, int wid)
+ {
+ if (oneDimObj.Length % (len * wid) != 0)
+ return null;
+ int height = oneDimObj.Length / (len * wid);
+ T[,,] threeDimObj = new T[len, wid, height];
+
+ for (int i = 0; i < oneDimObj.Length; i++)
+ {
+ threeDimObj[i / (wid * height), (i / height) % wid, i % height] = oneDimObj[i];
+ }
+ return threeDimObj;
+ }
+
+ private List<VadOutputEntity> Infer(List<VadInputEntity> vadInputEntitys)
+ {
+ List<VadOutputEntity> vadOutputEntities = new List<VadOutputEntity>();
+ foreach (VadInputEntity vadInputEntity in vadInputEntitys)
+ {
+ int batchSize = 1;//_batchSize
+ var inputMeta = _onnxSession.InputMetadata;
+ var container = new List<NamedOnnxValue>();
+ int[] dim = new int[] { batchSize, vadInputEntity.Speech.Length / 400 / batchSize, 400 };
+ var tensor = new DenseTensor<float>(vadInputEntity.Speech, dim, false);
+ container.Add(NamedOnnxValue.CreateFromTensor<float>("speech", tensor));
+
+ int i = 0;
+ foreach (var cache in vadInputEntity.InCaches)
+ {
+ int[] cache_dim = new int[] { 1, 128, cache.Length / 128 / 1, 1 };
+ var cache_tensor = new DenseTensor<float>(cache, cache_dim, false);
+ container.Add(NamedOnnxValue.CreateFromTensor<float>("in_cache" + i.ToString(), cache_tensor));
+ i++;
+ }
+
+ IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = _onnxSession.Run(container);
+ var resultsArray = results.ToArray();
+ VadOutputEntity vadOutputEntity = new VadOutputEntity();
+ for (int j = 0; j < resultsArray.Length; j++)
+ {
+ if (resultsArray[j].Name.Equals("logits"))
+ {
+ Tensor<float> tensors = resultsArray[0].AsTensor<float>();
+ var _scores = DimOneToThree<float>(tensors.ToArray(), 1, tensors.Dimensions[1]);
+ vadOutputEntity.Scores = _scores;
+ }
+ if (resultsArray[j].Name.StartsWith("out_cache"))
+ {
+ vadOutputEntity.OutCaches.Add(resultsArray[j].AsEnumerable<float>().ToArray());
+ }
+
+ }
+ vadOutputEntities.Add(vadOutputEntity);
+ }
+
+ return vadOutputEntities;
+ }
+
+ private float[] PadSequence(List<VadInputEntity> modelInputs)
+ {
+ int max_speech_length = modelInputs.Max(x => x.SpeechLength);
+ int speech_length = max_speech_length * modelInputs.Count;
+ float[] speech = new float[speech_length];
+ float[,] xxx = new float[modelInputs.Count, max_speech_length];
+ for (int i = 0; i < modelInputs.Count; i++)
+ {
+ if (max_speech_length == modelInputs[i].SpeechLength)
+ {
+ for (int j = 0; j < xxx.GetLength(1); j++)
+ {
+#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ xxx[i, j] = modelInputs[i].Speech[j];
+#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ }
+ continue;
+ }
+ float[] nullspeech = new float[max_speech_length - modelInputs[i].SpeechLength];
+ float[]? curr_speech = modelInputs[i].Speech;
+ float[] padspeech = new float[max_speech_length];
+ // ///////////////////////////////////////////////////
+ var arr_neg_mean = _onnxSession.ModelMetadata.CustomMetadataMap["neg_mean"].ToString().Split(',').ToArray();
+ double[] neg_mean = arr_neg_mean.Select(x => (double)Convert.ToDouble(x)).ToArray();
+ var arr_inv_stddev = _onnxSession.ModelMetadata.CustomMetadataMap["inv_stddev"].ToString().Split(',').ToArray();
+ double[] inv_stddev = arr_inv_stddev.Select(x => (double)Convert.ToDouble(x)).ToArray();
+
+ int dim = neg_mean.Length;
+ for (int j = 0; j < max_speech_length; j++)
+ {
+ int k = new Random().Next(0, dim);
+ padspeech[j] = (float)((float)(0 + neg_mean[k]) * inv_stddev[k]);
+ }
+ Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
+ for (int j = 0; j < padspeech.Length; j++)
+ {
+#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ xxx[i, j] = padspeech[j];
+#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
+ }
+
+ }
+ int s = 0;
+ for (int i = 0; i < xxx.GetLength(0); i++)
+ {
+ for (int j = 0; j < xxx.GetLength(1); j++)
+ {
+ speech[s] = xxx[i, j];
+ s++;
+ }
+ }
+ return speech;
+ }
+
+
+
+
+
+
+
+
+
+
+
+
+ }
+}
\ No newline at end of file
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj
new file mode 100644
index 0000000..4991517
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/AliFsmnVadSharp.csproj
@@ -0,0 +1,37 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+ <PropertyGroup>
+ <TargetFramework>net6.0</TargetFramework>
+ <ImplicitUsings>enable</ImplicitUsings>
+ <Nullable>enable</Nullable>
+ </PropertyGroup>
+
+ <ItemGroup>
+ <PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
+ <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.15.0" />
+ <PackageReference Include="YamlDotNet" Version="13.1.0" />
+ </ItemGroup>
+
+ <ItemGroup>
+ <None Update="Lib\kaldi-native-fbank-dll.dll">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ <TargetPath>kaldi-native-fbank-dll.dll</TargetPath>
+ </None>
+ <None Update="speech_fsmn_vad_zh-cn-16k-common-pytorch\example\0.wav">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </None>
+ <None Update="speech_fsmn_vad_zh-cn-16k-common-pytorch\example\1.wav">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </None>
+ <None Update="speech_fsmn_vad_zh-cn-16k-common-pytorch\model.onnx">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </None>
+ <None Update="speech_fsmn_vad_zh-cn-16k-common-pytorch\vad.mvn">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </None>
+ <None Update="speech_fsmn_vad_zh-cn-16k-common-pytorch\vad.yaml">
+ <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+ </None>
+ </ItemGroup>
+
+</Project>
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs
new file mode 100644
index 0000000..af0ad36
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KaldiNativeFbank.cs
@@ -0,0 +1,40 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using System.Runtime.InteropServices;
+using AliFsmnVadSharp.Struct;
+
+namespace AliFsmnVadSharp.DLL
+{
+ public static class KaldiNativeFbank
+ {
+ private const string dllName = @"kaldi-native-fbank-dll";
+
+ [DllImport(dllName, EntryPoint = "GetFbankOptions", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr GetFbankOptions(float dither, bool snip_edges, float sample_rate, int num_bins, float frame_shift = 10.0f, float frame_length = 25.0f, float energy_floor = 0.0f, bool debug_mel = false, string window_type = "hamming");
+
+ [DllImport(dllName, EntryPoint = "GetOnlineFbank", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern KnfOnlineFbank GetOnlineFbank(IntPtr opts);
+
+ [DllImport(dllName, EntryPoint = "AcceptWaveform", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void AcceptWaveform(KnfOnlineFbank knfOnlineFbank, float sample_rate, float[] samples, int samples_size);
+
+ [DllImport(dllName, EntryPoint = "InputFinished", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void InputFinished(KnfOnlineFbank knfOnlineFbank);
+
+ [DllImport(dllName, EntryPoint = "GetNumFramesReady", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int GetNumFramesReady(KnfOnlineFbank knfOnlineFbank);
+
+ [DllImport(dllName, EntryPoint = "AcceptWaveformxxx", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern FbankDatas AcceptWaveformxxx(KnfOnlineFbank knfOnlineFbank, float sample_rate, float[] samples, int samples_size);
+
+ [DllImport(dllName, EntryPoint = "GetFbank", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void GetFbank(KnfOnlineFbank knfOnlineFbank,int frame, ref FbankData pData);
+
+ [DllImport(dllName, EntryPoint = "GetFbanks", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void GetFbanks(KnfOnlineFbank knfOnlineFbank, int framesNum, ref FbankDatas fbankDatas);
+
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs
new file mode 100644
index 0000000..45549b2
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/DLL/KnfOnlineFbank.cs
@@ -0,0 +1,26 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.DLL
+{
+ internal struct FbankData
+ {
+ public IntPtr data;
+ public int data_length;
+ };
+
+ internal struct FbankDatas
+ {
+ public IntPtr data;
+ public int data_length;
+ };
+
+ internal struct KnfOnlineFbank
+ {
+ public IntPtr impl;
+ };
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs b/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs
new file mode 100644
index 0000000..ce519b1
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/E2EVadModel.cs
@@ -0,0 +1,717 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AliFsmnVadSharp.Model;
+
+namespace AliFsmnVadSharp
+{
+ enum VadStateMachine
+ {
+ kVadInStateStartPointNotDetected = 1,
+ kVadInStateInSpeechSegment = 2,
+ kVadInStateEndPointDetected = 3,
+ }
+ enum VadDetectMode
+ {
+ kVadSingleUtteranceDetectMode = 0,
+ kVadMutipleUtteranceDetectMode = 1,
+ }
+
+
+ internal class E2EVadModel
+ {
+ private VadPostConfEntity _vad_opts = new VadPostConfEntity();
+ private WindowDetector _windows_detector = new WindowDetector();
+ private bool _is_final = false;
+ private int _data_buf_start_frame = 0;
+ private int _frm_cnt = 0;
+ private int _latest_confirmed_speech_frame = 0;
+ private int _lastest_confirmed_silence_frame = -1;
+ private int _continous_silence_frame_count = 0;
+ private int _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected;
+ private int _confirmed_start_frame = -1;
+ private int _confirmed_end_frame = -1;
+ private int _number_end_time_detected = 0;
+ private int _sil_frame = 0;
+ private int[] _sil_pdf_ids = new int[0];
+ private double _noise_average_decibel = -100.0D;
+ private bool _pre_end_silence_detected = false;
+ private bool _next_seg = true;
+
+ private List<E2EVadSpeechBufWithDoaEntity> _output_data_buf;
+ private int _output_data_buf_offset = 0;
+ private List<E2EVadFrameProbEntity> _frame_probs = new List<E2EVadFrameProbEntity>();
+ private int _max_end_sil_frame_cnt_thresh = 800 - 150;
+ private float _speech_noise_thres = 0.6F;
+ private float[,,] _scores = null;
+ private int _idx_pre_chunk = 0;
+ private bool _max_time_out = false;
+ private List<double> _decibel = new List<double>();
+ private int _data_buf_size = 0;
+ private int _data_buf_all_size = 0;
+
+ public E2EVadModel(VadPostConfEntity vadPostConfEntity)
+ {
+ _vad_opts = vadPostConfEntity;
+ _windows_detector = new WindowDetector(_vad_opts.window_size_ms,
+ _vad_opts.sil_to_speech_time_thres,
+ _vad_opts.speech_to_sil_time_thres,
+ _vad_opts.frame_in_ms);
+ AllResetDetection();
+ }
+
+ private void AllResetDetection()
+ {
+ _is_final = false;
+ _data_buf_start_frame = 0;
+ _frm_cnt = 0;
+ _latest_confirmed_speech_frame = 0;
+ _lastest_confirmed_silence_frame = -1;
+ _continous_silence_frame_count = 0;
+ _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected;
+ _confirmed_start_frame = -1;
+ _confirmed_end_frame = -1;
+ _number_end_time_detected = 0;
+ _sil_frame = 0;
+ _sil_pdf_ids = _vad_opts.sil_pdf_ids;
+ _noise_average_decibel = -100.0F;
+ _pre_end_silence_detected = false;
+ _next_seg = true;
+
+ _output_data_buf = new List<E2EVadSpeechBufWithDoaEntity>();
+ _output_data_buf_offset = 0;
+ _frame_probs = new List<E2EVadFrameProbEntity>();
+ _max_end_sil_frame_cnt_thresh = _vad_opts.max_end_silence_time - _vad_opts.speech_to_sil_time_thres;
+ _speech_noise_thres = _vad_opts.speech_noise_thres;
+ _scores = null;
+ _idx_pre_chunk = 0;
+ _max_time_out = false;
+ _decibel = new List<double>();
+ _data_buf_size = 0;
+ _data_buf_all_size = 0;
+ ResetDetection();
+ }
+
+ private void ResetDetection()
+ {
+ _continous_silence_frame_count = 0;
+ _latest_confirmed_speech_frame = 0;
+ _lastest_confirmed_silence_frame = -1;
+ _confirmed_start_frame = -1;
+ _confirmed_end_frame = -1;
+ _vad_state_machine = (int)VadStateMachine.kVadInStateStartPointNotDetected;
+ _windows_detector.Reset();
+ _sil_frame = 0;
+ _frame_probs = new List<E2EVadFrameProbEntity>();
+ }
+
+ private void ComputeDecibel(float[] waveform)
+ {
+ int frame_sample_length = (int)(_vad_opts.frame_length_ms * _vad_opts.sample_rate / 1000);
+ int frame_shift_length = (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000);
+ if (_data_buf_all_size == 0)
+ {
+ _data_buf_all_size = waveform.Length;
+ _data_buf_size = _data_buf_all_size;
+ }
+ else
+ {
+ _data_buf_all_size += waveform.Length;
+ }
+
+ for (int offset = 0; offset < waveform.Length - frame_sample_length + 1; offset += frame_shift_length)
+ {
+ float[] _waveform_chunk = new float[frame_sample_length];
+ Array.Copy(waveform, offset, _waveform_chunk, 0, _waveform_chunk.Length);
+ float[] _waveform_chunk_pow = _waveform_chunk.Select(x => (float)Math.Pow((double)x, 2)).ToArray();
+ _decibel.Add(
+ 10 * Math.Log10(
+ _waveform_chunk_pow.Sum() + 0.000001
+ )
+ );
+ }
+
+ }
+
+ private void ComputeScores(float[,,] scores)
+ {
+ _vad_opts.nn_eval_block_size = scores.GetLength(1);
+ _frm_cnt += scores.GetLength(1);
+ _scores = scores;
+ }
+
+ private void PopDataBufTillFrame(int frame_idx)// need check again
+ {
+ while (_data_buf_start_frame < frame_idx)
+ {
+ if (_data_buf_size >= (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000))
+ {
+ _data_buf_start_frame += 1;
+ _data_buf_size = _data_buf_all_size - _data_buf_start_frame * (int)(_vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000);
+ }
+ }
+ }
+
+ private void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point,
+ bool last_frm_is_end_point, bool end_point_is_sent_end)
+ {
+ PopDataBufTillFrame(start_frm);
+ int expected_sample_number = (int)(frm_cnt * _vad_opts.sample_rate * _vad_opts.frame_in_ms / 1000);
+ if (last_frm_is_end_point)
+ {
+ int extra_sample = Math.Max(0, (int)(_vad_opts.frame_length_ms * _vad_opts.sample_rate / 1000 - _vad_opts.sample_rate * _vad_opts.frame_in_ms / 1000));
+ expected_sample_number += (int)(extra_sample);
+ }
+
+ if (end_point_is_sent_end)
+ {
+ expected_sample_number = Math.Max(expected_sample_number, _data_buf_size);
+ }
+ if (_data_buf_size < expected_sample_number)
+ {
+ Console.WriteLine("error in calling pop data_buf\n");
+ }
+
+ if (_output_data_buf.Count == 0 || first_frm_is_start_point)
+ {
+ _output_data_buf.Add(new E2EVadSpeechBufWithDoaEntity());
+ _output_data_buf.Last().Reset();
+ _output_data_buf.Last().start_ms = start_frm * _vad_opts.frame_in_ms;
+ _output_data_buf.Last().end_ms = _output_data_buf.Last().start_ms;
+ _output_data_buf.Last().doa = 0;
+ }
+
+ E2EVadSpeechBufWithDoaEntity cur_seg = _output_data_buf.Last();
+ if (cur_seg.end_ms != start_frm * _vad_opts.frame_in_ms)
+ {
+ Console.WriteLine("warning\n");
+ }
+
+ int out_pos = cur_seg.buffer.Length; // cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
+ int data_to_pop = 0;
+ if (end_point_is_sent_end)
+ {
+ data_to_pop = expected_sample_number;
+ }
+ else
+ {
+ data_to_pop = (int)(frm_cnt * _vad_opts.frame_in_ms * _vad_opts.sample_rate / 1000);
+ }
+ if (data_to_pop > _data_buf_size)
+ {
+ Console.WriteLine("VAD data_to_pop is bigger than _data_buf_size!!!\n");
+ data_to_pop = _data_buf_size;
+ expected_sample_number = _data_buf_size;
+ }
+
+
+ cur_seg.doa = 0;
+ for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++)
+ {
+ out_pos += 1;
+ }
+ for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++)
+ {
+ out_pos += 1;
+ }
+
+ if (cur_seg.end_ms != start_frm * _vad_opts.frame_in_ms)
+ {
+ Console.WriteLine("Something wrong with the VAD algorithm\n");
+ }
+
+ _data_buf_start_frame += frm_cnt;
+ cur_seg.end_ms = (start_frm + frm_cnt) * _vad_opts.frame_in_ms;
+ if (first_frm_is_start_point)
+ {
+ cur_seg.contain_seg_start_point = true;
+ }
+
+ if (last_frm_is_end_point)
+ {
+ cur_seg.contain_seg_end_point = true;
+ }
+ }
+
+ private void OnSilenceDetected(int valid_frame)
+ {
+ _lastest_confirmed_silence_frame = valid_frame;
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected)
+ {
+ PopDataBufTillFrame(valid_frame);
+ }
+
+ }
+
+ private void OnVoiceDetected(int valid_frame)
+ {
+ _latest_confirmed_speech_frame = valid_frame;
+ PopDataToOutputBuf(valid_frame, 1, false, false, false);
+ }
+
+ private void OnVoiceStart(int start_frame, bool fake_result = false)
+ {
+ if (_vad_opts.do_start_point_detection)
+ {
+ //do nothing
+ }
+ if (_confirmed_start_frame != -1)
+ {
+
+ Console.WriteLine("not reset vad properly\n");
+ }
+ else
+ {
+ _confirmed_start_frame = start_frame;
+ }
+ if (!fake_result || _vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected)
+ {
+
+ PopDataToOutputBuf(_confirmed_start_frame, 1, true, false, false);
+ }
+ }
+
+ private void OnVoiceEnd(int end_frame, bool fake_result, bool is_last_frame)
+ {
+ for (int t = _latest_confirmed_speech_frame + 1; t < end_frame; t++)
+ {
+ OnVoiceDetected(t);
+ }
+ if (_vad_opts.do_end_point_detection)
+ {
+ //do nothing
+ }
+ if (_confirmed_end_frame != -1)
+ {
+ Console.WriteLine("not reset vad properly\n");
+ }
+ else
+ {
+ _confirmed_end_frame = end_frame;
+ }
+ if (!fake_result)
+ {
+ _sil_frame = 0;
+ PopDataToOutputBuf(_confirmed_end_frame, 1, false, true, is_last_frame);
+ }
+ _number_end_time_detected += 1;
+ }
+
+ private void MaybeOnVoiceEndIfLastFrame(bool is_final_frame, int cur_frm_idx)
+ {
+ if (is_final_frame)
+ {
+ OnVoiceEnd(cur_frm_idx, false, true);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+ }
+
+ private int GetLatency()
+ {
+ return (int)(LatencyFrmNumAtStartPoint() * _vad_opts.frame_in_ms);
+ }
+
+ private int LatencyFrmNumAtStartPoint()
+ {
+ int vad_latency = _windows_detector.GetWinSize();
+ if (_vad_opts.do_extend != 0)
+ {
+ vad_latency += (int)(_vad_opts.lookback_time_start_point / _vad_opts.frame_in_ms);
+ }
+ return vad_latency;
+ }
+
+ private FrameState GetFrameState(int t)
+ {
+
+ FrameState frame_state = FrameState.kFrameStateInvalid;
+ double cur_decibel = _decibel[t];
+ double cur_snr = cur_decibel - _noise_average_decibel;
+ if (cur_decibel < _vad_opts.decibel_thres)
+ {
+ frame_state = FrameState.kFrameStateSil;
+ DetectOneFrame(frame_state, t, false);
+ return frame_state;
+ }
+
+
+ double sum_score = 0.0D;
+ double noise_prob = 0.0D;
+ Trace.Assert(_sil_pdf_ids.Length == _vad_opts.silence_pdf_num, "");
+ if (_sil_pdf_ids.Length > 0)
+ {
+ Trace.Assert(_scores.GetLength(0) == 1, "鍙敮鎸乥atch_size = 1鐨勬祴璇�"); // 鍙敮鎸乥atch_size = 1鐨勬祴璇�
+ float[] sil_pdf_scores = new float[_sil_pdf_ids.Length];
+ int j = 0;
+ foreach (int sil_pdf_id in _sil_pdf_ids)
+ {
+ sil_pdf_scores[j] = _scores[0,t - _idx_pre_chunk,sil_pdf_id];
+ j++;
+ }
+ sum_score = sil_pdf_scores.Length == 0 ? 0 : sil_pdf_scores.Sum();
+ noise_prob = Math.Log(sum_score) * _vad_opts.speech_2_noise_ratio;
+ double total_score = 1.0D;
+ sum_score = total_score - sum_score;
+ }
+ double speech_prob = Math.Log(sum_score);
+ if (_vad_opts.output_frame_probs)
+ {
+ E2EVadFrameProbEntity frame_prob = new E2EVadFrameProbEntity();
+ frame_prob.noise_prob = noise_prob;
+ frame_prob.speech_prob = speech_prob;
+ frame_prob.score = sum_score;
+ frame_prob.frame_id = t;
+ _frame_probs.Add(frame_prob);
+ }
+
+ if (Math.Exp(speech_prob) >= Math.Exp(noise_prob) + _speech_noise_thres)
+ {
+ if (cur_snr >= _vad_opts.snr_thres && cur_decibel >= _vad_opts.decibel_thres)
+ {
+ frame_state = FrameState.kFrameStateSpeech;
+ }
+ else
+ {
+ frame_state = FrameState.kFrameStateSil;
+ }
+ }
+ else
+ {
+ frame_state = FrameState.kFrameStateSil;
+ if (_noise_average_decibel < -99.9)
+ {
+ _noise_average_decibel = cur_decibel;
+ }
+ else
+ {
+ _noise_average_decibel = (cur_decibel + _noise_average_decibel * (_vad_opts.noise_frame_num_used_for_snr - 1)) / _vad_opts.noise_frame_num_used_for_snr;
+ }
+ }
+ return frame_state;
+ }
+
+ public SegmentEntity[] DefaultCall(float[,,] score, float[] waveform,
+ bool is_final = false, int max_end_sil = 800, bool online = false
+ )
+ {
+ _max_end_sil_frame_cnt_thresh = max_end_sil - _vad_opts.speech_to_sil_time_thres;
+ // compute decibel for each frame
+ ComputeDecibel(waveform);
+ ComputeScores(score);
+ if (!is_final)
+ {
+ DetectCommonFrames();
+ }
+ else
+ {
+ DetectLastFrames();
+ }
+ int batchSize = score.GetLength(0);
+ SegmentEntity[] segments = new SegmentEntity[batchSize];
+ for (int batch_num = 0; batch_num < batchSize; batch_num++) // only support batch_size = 1 now
+ {
+ List<int[]> segment_batch = new List<int[]>();
+ if (_output_data_buf.Count > 0)
+ {
+ for (int i = _output_data_buf_offset; i < _output_data_buf.Count; i++)
+ {
+ int start_ms;
+ int end_ms;
+ if (online)
+ {
+ if (!_output_data_buf[i].contain_seg_start_point)
+ {
+ continue;
+ }
+ if (!_next_seg && !_output_data_buf[i].contain_seg_end_point)
+ {
+ continue;
+ }
+ start_ms = _next_seg ? _output_data_buf[i].start_ms : -1;
+ if (_output_data_buf[i].contain_seg_end_point)
+ {
+ end_ms = _output_data_buf[i].end_ms;
+ _next_seg = true;
+ _output_data_buf_offset += 1;
+ }
+ else
+ {
+ end_ms = -1;
+ _next_seg = false;
+ }
+ }
+ else
+ {
+ if (!is_final && (!_output_data_buf[i].contain_seg_start_point || !_output_data_buf[i].contain_seg_end_point))
+ {
+ continue;
+ }
+ start_ms = _output_data_buf[i].start_ms;
+ end_ms = _output_data_buf[i].end_ms;
+ _output_data_buf_offset += 1;
+
+ }
+ int[] segment_ms = new int[] { start_ms, end_ms };
+ segment_batch.Add(segment_ms);
+
+ }
+
+ }
+
+ if (segment_batch.Count > 0)
+ {
+ if (segments[batch_num] == null)
+ {
+ segments[batch_num] = new SegmentEntity();
+ }
+ segments[batch_num].Segment.AddRange(segment_batch);
+ }
+ }
+
+ if (is_final)
+ {
+ // reset class variables and clear the dict for the next query
+ AllResetDetection();
+ }
+
+ return segments;
+ }
+
+ private int DetectCommonFrames()
+ {
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected)
+ {
+ return 0;
+ }
+ for (int i = _vad_opts.nn_eval_block_size - 1; i > -1; i += -1)
+ {
+ FrameState frame_state = FrameState.kFrameStateInvalid;
+ frame_state = GetFrameState(_frm_cnt - 1 - i);
+ DetectOneFrame(frame_state, _frm_cnt - 1 - i, false);
+ }
+
+ _idx_pre_chunk += _scores.GetLength(1)* _scores.GetLength(0); //_scores.shape[1];
+ return 0;
+ }
+
+ private int DetectLastFrames()
+ {
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected)
+ {
+ return 0;
+ }
+ for (int i = _vad_opts.nn_eval_block_size - 1; i > -1; i += -1)
+ {
+ FrameState frame_state = FrameState.kFrameStateInvalid;
+ frame_state = GetFrameState(_frm_cnt - 1 - i);
+ if (i != 0)
+ {
+ DetectOneFrame(frame_state, _frm_cnt - 1 - i, false);
+ }
+ else
+ {
+ DetectOneFrame(frame_state, _frm_cnt - 1, true);
+ }
+
+
+ }
+
+ return 0;
+ }
+
+ private void DetectOneFrame(FrameState cur_frm_state, int cur_frm_idx, bool is_final_frame)
+ {
+ FrameState tmp_cur_frm_state = FrameState.kFrameStateInvalid;
+ if (cur_frm_state == FrameState.kFrameStateSpeech)
+ {
+ if (Math.Abs(1.0) > _vad_opts.fe_prior_thres)//Fabs
+ {
+ tmp_cur_frm_state = FrameState.kFrameStateSpeech;
+ }
+ else
+ {
+ tmp_cur_frm_state = FrameState.kFrameStateSil;
+ }
+ }
+ else if (cur_frm_state == FrameState.kFrameStateSil)
+ {
+ tmp_cur_frm_state = FrameState.kFrameStateSil;
+ }
+
+ AudioChangeState state_change = _windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx);
+ int frm_shift_in_ms = _vad_opts.frame_in_ms;
+ if (AudioChangeState.kChangeStateSil2Speech == state_change)
+ {
+ int silence_frame_count = _continous_silence_frame_count; // no used
+ _continous_silence_frame_count = 0;
+ _pre_end_silence_detected = false;
+ int start_frame = 0;
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected)
+ {
+ start_frame = Math.Max(_data_buf_start_frame, cur_frm_idx - LatencyFrmNumAtStartPoint());
+ OnVoiceStart(start_frame);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateInSpeechSegment;
+ for (int t = start_frame + 1; t < cur_frm_idx + 1; t++)
+ {
+ OnVoiceDetected(t);
+ }
+
+ }
+ else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment)
+ {
+ for (int t = _latest_confirmed_speech_frame + 1; t < cur_frm_idx; t++)
+ {
+ OnVoiceDetected(t);
+ }
+ if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms)
+ {
+ OnVoiceEnd(cur_frm_idx, false, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+
+ else if (!is_final_frame)
+ {
+ OnVoiceDetected(cur_frm_idx);
+ }
+ else
+ {
+ MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+ }
+
+ }
+ else
+ {
+ return;
+ }
+ }
+ else if (AudioChangeState.kChangeStateSpeech2Sil == state_change)
+ {
+ _continous_silence_frame_count = 0;
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected)
+ { return; }
+ else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment)
+ {
+ if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms)
+ {
+ OnVoiceEnd(cur_frm_idx, false, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+ else if (!is_final_frame)
+ {
+ OnVoiceDetected(cur_frm_idx);
+ }
+ else
+ {
+ MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+ }
+
+ }
+ else
+ {
+ return;
+ }
+ }
+ else if (AudioChangeState.kChangeStateSpeech2Speech == state_change)
+ {
+ _continous_silence_frame_count = 0;
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment)
+ {
+ if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms)
+ {
+ _max_time_out = true;
+ OnVoiceEnd(cur_frm_idx, false, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+ else if (!is_final_frame)
+ {
+ OnVoiceDetected(cur_frm_idx);
+ }
+ else
+ {
+ MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+ }
+ }
+ else
+ {
+ return;
+ }
+
+ }
+ else if (AudioChangeState.kChangeStateSil2Sil == state_change)
+ {
+ _continous_silence_frame_count += 1;
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateStartPointNotDetected)
+ {
+ // silence timeout, return zero length decision
+ if (((_vad_opts.detect_mode == (int)VadDetectMode.kVadSingleUtteranceDetectMode) && (
+ _continous_silence_frame_count * frm_shift_in_ms > _vad_opts.max_start_silence_time)) || (is_final_frame && _number_end_time_detected == 0))
+ {
+ for (int t = _lastest_confirmed_silence_frame + 1; t < cur_frm_idx; t++)
+ {
+ OnSilenceDetected(t);
+ }
+ OnVoiceStart(0, true);
+ OnVoiceEnd(0, true, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+ else
+ {
+ if (cur_frm_idx >= LatencyFrmNumAtStartPoint())
+ {
+ OnSilenceDetected(cur_frm_idx - LatencyFrmNumAtStartPoint());
+ }
+ }
+ }
+ else if (_vad_state_machine == (int)VadStateMachine.kVadInStateInSpeechSegment)
+ {
+ if (_continous_silence_frame_count * frm_shift_in_ms >= _max_end_sil_frame_cnt_thresh)
+ {
+ int lookback_frame = (int)(_max_end_sil_frame_cnt_thresh / frm_shift_in_ms);
+ if (_vad_opts.do_extend != 0)
+ {
+ lookback_frame -= (int)(_vad_opts.lookahead_time_end_point / frm_shift_in_ms);
+ lookback_frame -= 1;
+ lookback_frame = Math.Max(0, lookback_frame);
+ }
+
+ OnVoiceEnd(cur_frm_idx - lookback_frame, false, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+ else if (cur_frm_idx - _confirmed_start_frame + 1 > _vad_opts.max_single_segment_time / frm_shift_in_ms)
+ {
+ OnVoiceEnd(cur_frm_idx, false, false);
+ _vad_state_machine = (int)VadStateMachine.kVadInStateEndPointDetected;
+ }
+
+ else if (_vad_opts.do_extend != 0 && !is_final_frame)
+ {
+ if (_continous_silence_frame_count <= (int)(_vad_opts.lookahead_time_end_point / frm_shift_in_ms))
+ {
+ OnVoiceDetected(cur_frm_idx);
+ }
+ }
+
+ else
+ {
+ MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+ }
+ }
+ else
+ {
+ return;
+ }
+
+ }
+
+ if (_vad_state_machine == (int)VadStateMachine.kVadInStateEndPointDetected && _vad_opts.detect_mode == (int)VadDetectMode.kVadMutipleUtteranceDetectMode)
+ {
+ ResetDetection();
+ }
+
+ }
+
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll b/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll
new file mode 100644
index 0000000..cddc940
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Lib/kaldi-native-fbank-dll.dll
Binary files differ
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs
new file mode 100644
index 0000000..2f93df1
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/CmvnEntity.cs
@@ -0,0 +1,17 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ internal class CmvnEntity
+ {
+ private List<float> _means = new List<float>();
+ private List<float> _vars = new List<float>();
+
+ public List<float> Means { get => _means; set => _means = value; }
+ public List<float> Vars { get => _vars; set => _vars = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs
new file mode 100644
index 0000000..58a4ca9
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadFrameProbEntity.cs
@@ -0,0 +1,23 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ internal class E2EVadFrameProbEntity
+ {
+ private double _noise_prob = 0.0F;
+ private double _speech_prob = 0.0F;
+ private double _score = 0.0F;
+ private int _frame_id = 0;
+ private int _frm_state = 0;
+
+ public double noise_prob { get => _noise_prob; set => _noise_prob = value; }
+ public double speech_prob { get => _speech_prob; set => _speech_prob = value; }
+ public double score { get => _score; set => _score = value; }
+ public int frame_id { get => _frame_id; set => _frame_id = value; }
+ public int frm_state { get => _frm_state; set => _frm_state = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs
new file mode 100644
index 0000000..8c2e7f7
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/E2EVadSpeechBufWithDoaEntity.cs
@@ -0,0 +1,98 @@
+// AliFsmnVadSharp, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null
+// AliFsmnVadSharp.Model.E2EVadSpeechBufWithDoaEntity
+internal class E2EVadSpeechBufWithDoaEntity
+{
+ private int _start_ms = 0;
+
+ private int _end_ms = 0;
+
+ private byte[]? _buffer;
+
+ private bool _contain_seg_start_point = false;
+
+ private bool _contain_seg_end_point = false;
+
+ private int _doa = 0;
+
+ public int start_ms
+ {
+ get
+ {
+ return _start_ms;
+ }
+ set
+ {
+ _start_ms = value;
+ }
+ }
+
+ public int end_ms
+ {
+ get
+ {
+ return _end_ms;
+ }
+ set
+ {
+ _end_ms = value;
+ }
+ }
+
+ public byte[]? buffer
+ {
+ get
+ {
+ return _buffer;
+ }
+ set
+ {
+ _buffer = value;
+ }
+ }
+
+ public bool contain_seg_start_point
+ {
+ get
+ {
+ return _contain_seg_start_point;
+ }
+ set
+ {
+ _contain_seg_start_point = value;
+ }
+ }
+
+ public bool contain_seg_end_point
+ {
+ get
+ {
+ return _contain_seg_end_point;
+ }
+ set
+ {
+ _contain_seg_end_point = value;
+ }
+ }
+
+ public int doa
+ {
+ get
+ {
+ return _doa;
+ }
+ set
+ {
+ _doa = value;
+ }
+ }
+
+ public void Reset()
+ {
+ _start_ms = 0;
+ _end_ms = 0;
+ _buffer = new byte[0];
+ _contain_seg_start_point = false;
+ _contain_seg_end_point = false;
+ _doa = 0;
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs
new file mode 100644
index 0000000..8365b12
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/EncoderConfEntity.cs
@@ -0,0 +1,35 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ public class EncoderConfEntity
+ {
+ private int _input_dim=400;
+ private int _input_affineDim = 140;
+ private int _fsmn_layers = 4;
+ private int _linear_dim = 250;
+ private int _proj_dim = 128;
+ private int _lorder = 20;
+ private int _rorder = 0;
+ private int _lstride = 1;
+ private int _rstride = 0;
+ private int _output_dffine_dim = 140;
+ private int _output_dim = 248;
+
+ public int input_dim { get => _input_dim; set => _input_dim = value; }
+ public int input_affine_dim { get => _input_affineDim; set => _input_affineDim = value; }
+ public int fsmn_layers { get => _fsmn_layers; set => _fsmn_layers = value; }
+ public int linear_dim { get => _linear_dim; set => _linear_dim = value; }
+ public int proj_dim { get => _proj_dim; set => _proj_dim = value; }
+ public int lorder { get => _lorder; set => _lorder = value; }
+ public int rorder { get => _rorder; set => _rorder = value; }
+ public int lstride { get => _lstride; set => _lstride = value; }
+ public int rstride { get => _rstride; set => _rstride = value; }
+ public int output_affine_dim { get => _output_dffine_dim; set => _output_dffine_dim = value; }
+ public int output_dim { get => _output_dim; set => _output_dim = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs
new file mode 100644
index 0000000..22bb35a
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/FrontendConfEntity.cs
@@ -0,0 +1,29 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ public class FrontendConfEntity
+ {
+ private int _fs = 16000;
+ private string _window = "hamming";
+ private int _n_mels = 80;
+ private int _frame_length = 25;
+ private int _frame_shift = 10;
+ private float _dither = 0.0F;
+ private int _lfr_m = 5;
+ private int _lfr_n = 1;
+
+ public int fs { get => _fs; set => _fs = value; }
+ public string window { get => _window; set => _window = value; }
+ public int n_mels { get => _n_mels; set => _n_mels = value; }
+ public int frame_length { get => _frame_length; set => _frame_length = value; }
+ public int frame_shift { get => _frame_shift; set => _frame_shift = value; }
+ public float dither { get => _dither; set => _dither = value; }
+ public int lfr_m { get => _lfr_m; set => _lfr_m = value; }
+ public int lfr_n { get => _lfr_n; set => _lfr_n = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs
new file mode 100644
index 0000000..bdb715d
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/SegmentEntity.cs
@@ -0,0 +1,22 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ public class SegmentEntity
+ {
+ private List<int[]> _segment=new List<int[]>();
+ private List<float[]> _waveform=new List<float[]>();
+
+ public List<int[]> Segment { get => _segment; set => _segment = value; }
+ public List<float[]> Waveform { get => _waveform; set => _waveform = value; }
+ //public SegmentEntity()
+ //{
+ // int[] t=new int[0];
+ // _segment.Add(t);
+ //}
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs
new file mode 100644
index 0000000..fcd63d8
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadInputEntity.cs
@@ -0,0 +1,23 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ internal class VadInputEntity
+ {
+ private float[]? _speech;
+ private int _speechLength;
+ private List<float[]> _inCaches = new List<float[]>();
+ private float[]? _waveform;
+ private E2EVadModel _vad_scorer;
+
+ public float[]? Speech { get => _speech; set => _speech = value; }
+ public int SpeechLength { get => _speechLength; set => _speechLength = value; }
+ public List<float[]> InCaches { get => _inCaches; set => _inCaches = value; }
+ public float[] Waveform { get => _waveform; set => _waveform = value; }
+ internal E2EVadModel VadScorer { get => _vad_scorer; set => _vad_scorer = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs
new file mode 100644
index 0000000..fa8639e
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadOutputEntity.cs
@@ -0,0 +1,19 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ internal class VadOutputEntity
+ {
+ private float[,,]? _scores;
+ private List<float[]> _outCaches=new List<float[]>();
+ private float[]? _waveform;
+
+ public float[,,]? Scores { get => _scores; set => _scores = value; }
+ public List<float[]> OutCaches { get => _outCaches; set => _outCaches = value; }
+ public float[] Waveform { get => _waveform; set => _waveform = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs
new file mode 100644
index 0000000..e566cf2
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadPostConfEntity.cs
@@ -0,0 +1,72 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ public class VadPostConfEntity
+ {
+ private int _sample_rate= 16000;
+ private int _detect_mode = 1 ;
+ private int _snr_mode = 0;
+ private int _max_end_silence_time = 800;
+ private int _max_start_silence_time = 3000;
+ private bool _do_start_point_detection = true;
+ private bool _do_end_point_detection = true;
+ private int _window_size_ms = 200;
+ private int _sil_to_speech_time_thres = 150;
+ private int _speech_to_sil_time_thres = 150;
+ private float _speech_2_noise_ratio = 1.0F;
+ private int _do_extend = 1;
+ private int _lookback_time_start_point = 200;
+ private int _lookahead_time_end_point = 100;
+ private int _max_single_segment_time = 60000;
+ private int _nn_eval_block_size = 8;
+ private int _dcd_block_size = 4;
+ private float _snr_thres = -100.0F;
+ private int _noise_frame_num_used_for_snr = 100;
+ private float _decibel_thres = -100.0F;
+ private float _speech_noise_thres = 0.6F;
+ private float _fe_prior_thres = 0.0001F;
+ private int _silence_pdf_num = 1;
+ private int[] _sil_pdf_ids = new int[] {0};
+ private float _speech_noise_thresh_low = -0.1F;
+ private float _speech_noise_thresh_high = 0.3F;
+ private bool _output_frame_probs = false;
+ private int _frame_in_ms = 10;
+ private int _frame_length_ms = 25;
+
+ public int sample_rate { get => _sample_rate; set => _sample_rate = value; }
+ public int detect_mode { get => _detect_mode; set => _detect_mode = value; }
+ public int snr_mode { get => _snr_mode; set => _snr_mode = value; }
+ public int max_end_silence_time { get => _max_end_silence_time; set => _max_end_silence_time = value; }
+ public int max_start_silence_time { get => _max_start_silence_time; set => _max_start_silence_time = value; }
+ public bool do_start_point_detection { get => _do_start_point_detection; set => _do_start_point_detection = value; }
+ public bool do_end_point_detection { get => _do_end_point_detection; set => _do_end_point_detection = value; }
+ public int window_size_ms { get => _window_size_ms; set => _window_size_ms = value; }
+ public int sil_to_speech_time_thres { get => _sil_to_speech_time_thres; set => _sil_to_speech_time_thres = value; }
+ public int speech_to_sil_time_thres { get => _speech_to_sil_time_thres; set => _speech_to_sil_time_thres = value; }
+ public float speech_2_noise_ratio { get => _speech_2_noise_ratio; set => _speech_2_noise_ratio = value; }
+ public int do_extend { get => _do_extend; set => _do_extend = value; }
+ public int lookback_time_start_point { get => _lookback_time_start_point; set => _lookback_time_start_point = value; }
+ public int lookahead_time_end_point { get => _lookahead_time_end_point; set => _lookahead_time_end_point = value; }
+ public int max_single_segment_time { get => _max_single_segment_time; set => _max_single_segment_time = value; }
+ public int nn_eval_block_size { get => _nn_eval_block_size; set => _nn_eval_block_size = value; }
+ public int dcd_block_size { get => _dcd_block_size; set => _dcd_block_size = value; }
+ public float snr_thres { get => _snr_thres; set => _snr_thres = value; }
+ public int noise_frame_num_used_for_snr { get => _noise_frame_num_used_for_snr; set => _noise_frame_num_used_for_snr = value; }
+ public float decibel_thres { get => _decibel_thres; set => _decibel_thres = value; }
+ public float speech_noise_thres { get => _speech_noise_thres; set => _speech_noise_thres = value; }
+ public float fe_prior_thres { get => _fe_prior_thres; set => _fe_prior_thres = value; }
+ public int silence_pdf_num { get => _silence_pdf_num; set => _silence_pdf_num = value; }
+ public int[] sil_pdf_ids { get => _sil_pdf_ids; set => _sil_pdf_ids = value; }
+ public float speech_noise_thresh_low { get => _speech_noise_thresh_low; set => _speech_noise_thresh_low = value; }
+ public float speech_noise_thresh_high { get => _speech_noise_thresh_high; set => _speech_noise_thresh_high = value; }
+ public bool output_frame_probs { get => _output_frame_probs; set => _output_frame_probs = value; }
+ public int frame_in_ms { get => _frame_in_ms; set => _frame_in_ms = value; }
+ public int frame_length_ms { get => _frame_length_ms; set => _frame_length_ms = value; }
+
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs
new file mode 100644
index 0000000..65e77ed
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Model/VadYamlEntity.cs
@@ -0,0 +1,27 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp.Model
+{
+ internal class VadYamlEntity
+ {
+ private int _input_size;
+ private string _frontend = "wav_frontend";
+ private FrontendConfEntity _frontend_conf=new FrontendConfEntity();
+ private string _model = "e2evad";
+ private string _encoder = "fsmn";
+ private EncoderConfEntity _encoder_conf=new EncoderConfEntity();
+ private VadPostConfEntity _vad_post_conf=new VadPostConfEntity();
+
+ public int input_size { get => _input_size; set => _input_size = value; }
+ public string frontend { get => _frontend; set => _frontend = value; }
+ public string model { get => _model; set => _model = value; }
+ public string encoder { get => _encoder; set => _encoder = value; }
+ public FrontendConfEntity frontend_conf { get => _frontend_conf; set => _frontend_conf = value; }
+ public EncoderConfEntity encoder_conf { get => _encoder_conf; set => _encoder_conf = value; }
+ public VadPostConfEntity vad_post_conf { get => _vad_post_conf; set => _vad_post_conf = value; }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs
new file mode 100644
index 0000000..bbad3dc
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Struct/FbankData.cs
@@ -0,0 +1,6 @@
+锘縰sing System.Runtime.InteropServices;
+
+namespace AliFsmnVadSharp.Struct
+{
+
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs b/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs
new file mode 100644
index 0000000..0b460ff
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/Utils/YamlHelper.cs
@@ -0,0 +1,28 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using System.Text.Json;
+using YamlDotNet.Serialization;
+
+namespace AliFsmnVadSharp.Utils
+{
+ internal class YamlHelper
+ {
+ public static T ReadYaml<T>(string yamlFilePath)
+ {
+ if (!File.Exists(yamlFilePath))
+ {
+#pragma warning disable CS8603 // 鍙兘杩斿洖 null 寮曠敤銆�
+ return default(T);
+#pragma warning restore CS8603 // 鍙兘杩斿洖 null 寮曠敤銆�
+ }
+ StreamReader yamlReader = File.OpenText(yamlFilePath);
+ Deserializer yamlDeserializer = new Deserializer();
+ T info = yamlDeserializer.Deserialize<T>(yamlReader);
+ yamlReader.Close();
+ return info;
+ }
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs b/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs
new file mode 100644
index 0000000..2c5b50f
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/WavFrontend.cs
@@ -0,0 +1,185 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using AliFsmnVadSharp.Model;
+using AliFsmnVadSharp.DLL;
+using AliFsmnVadSharp.Struct;
+using System.Runtime.InteropServices;
+
+namespace AliFsmnVadSharp
+{
+ internal class WavFrontend
+ {
+ private string _mvnFilePath;
+ private FrontendConfEntity _frontendConfEntity;
+ IntPtr _opts = IntPtr.Zero;
+ private CmvnEntity _cmvnEntity;
+
+ private static int _fbank_beg_idx = 0;
+
+ public WavFrontend(string mvnFilePath, FrontendConfEntity frontendConfEntity)
+ {
+ _mvnFilePath = mvnFilePath;
+ _frontendConfEntity = frontendConfEntity;
+ _fbank_beg_idx = 0;
+ _opts = KaldiNativeFbank.GetFbankOptions(
+ dither: _frontendConfEntity.dither,
+ snip_edges: true,
+ sample_rate: _frontendConfEntity.fs,
+ num_bins: _frontendConfEntity.n_mels
+ );
+ _cmvnEntity = LoadCmvn(mvnFilePath);
+ }
+
+ public float[] GetFbank(float[] samples)
+ {
+ float sample_rate = _frontendConfEntity.fs;
+ samples = samples.Select((float x) => x * 32768f).ToArray();
+ // method1
+ //FbankDatas fbankDatas = new FbankDatas();
+ //KaldiNativeFbank.GetFbanks(_knfOnlineFbank, framesNum,ref fbankDatas);
+ // method2
+ KnfOnlineFbank _knfOnlineFbank = KaldiNativeFbank.GetOnlineFbank(_opts);
+ KaldiNativeFbank.AcceptWaveform(_knfOnlineFbank, sample_rate, samples, samples.Length);
+ KaldiNativeFbank.InputFinished(_knfOnlineFbank);
+ int framesNum = KaldiNativeFbank.GetNumFramesReady(_knfOnlineFbank);
+ float[] fbanks = new float[framesNum * 80];
+ for (int i = 0; i < framesNum; i++)
+ {
+ FbankData fbankData = new FbankData();
+ KaldiNativeFbank.GetFbank(_knfOnlineFbank, i, ref fbankData);
+ float[] _fbankData = new float[fbankData.data_length];
+ Marshal.Copy(fbankData.data, _fbankData, 0, fbankData.data_length);
+ Array.Copy(_fbankData, 0, fbanks, i * 80, _fbankData.Length);
+ fbankData.data = IntPtr.Zero;
+ _fbankData = null;
+ }
+
+ samples = null;
+ GC.Collect();
+ return fbanks;
+ }
+
+
+ public float[] LfrCmvn(float[] fbanks)
+ {
+ float[] features = fbanks;
+ if (_frontendConfEntity.lfr_m != 1 || _frontendConfEntity.lfr_n != 1)
+ {
+ features = ApplyLfr(fbanks, _frontendConfEntity.lfr_m, _frontendConfEntity.lfr_n);
+ }
+ if (_cmvnEntity != null)
+ {
+ features = ApplyCmvn(features);
+ }
+ return features;
+ }
+
+ private float[] ApplyCmvn(float[] inputs)
+ {
+ var arr_neg_mean = _cmvnEntity.Means;
+ float[] neg_mean = arr_neg_mean.Select(x => (float)Convert.ToDouble(x)).ToArray();
+ var arr_inv_stddev = _cmvnEntity.Vars;
+ float[] inv_stddev = arr_inv_stddev.Select(x => (float)Convert.ToDouble(x)).ToArray();
+
+ int dim = neg_mean.Length;
+ int num_frames = inputs.Length / dim;
+
+ for (int i = 0; i < num_frames; i++)
+ {
+ for (int k = 0; k != dim; ++k)
+ {
+ inputs[dim * i + k] = (inputs[dim * i + k] + neg_mean[k]) * inv_stddev[k];
+ }
+ }
+ return inputs;
+ }
+
+ public float[] ApplyLfr(float[] inputs, int lfr_m, int lfr_n)
+ {
+ int t = inputs.Length / 80;
+ int t_lfr = (int)Math.Floor((double)(t / lfr_n));
+ float[] input_0 = new float[80];
+ Array.Copy(inputs, 0, input_0, 0, 80);
+ int tile_x = (lfr_m - 1) / 2;
+ t = t + tile_x;
+ float[] inputs_temp = new float[t * 80];
+ for (int i = 0; i < tile_x; i++)
+ {
+ Array.Copy(input_0, 0, inputs_temp, tile_x * 80, 80);
+ }
+ Array.Copy(inputs, 0, inputs_temp, tile_x * 80, inputs.Length);
+ inputs = inputs_temp;
+
+ float[] LFR_outputs = new float[t_lfr * lfr_m * 80];
+ for (int i = 0; i < t_lfr; i++)
+ {
+ if (lfr_m <= t - i * lfr_n)
+ {
+ Array.Copy(inputs, i * lfr_n * 80, LFR_outputs, i* lfr_m * 80, lfr_m * 80);
+ }
+ else
+ {
+ // process last LFR frame
+ int num_padding = lfr_m - (t - i * lfr_n);
+ float[] frame = new float[lfr_m * 80];
+ Array.Copy(inputs, i * lfr_n * 80, frame, 0, (t - i * lfr_n) * 80);
+
+ for (int j = 0; j < num_padding; j++)
+ {
+ Array.Copy(inputs, (t - 1) * 80, frame, (lfr_m - num_padding + j) * 80, 80);
+ }
+ Array.Copy(frame, 0, LFR_outputs, i * lfr_m * 80, frame.Length);
+ }
+ }
+ return LFR_outputs;
+ }
+
+ private CmvnEntity LoadCmvn(string mvnFilePath)
+ {
+ List<float> means_list = new List<float>();
+ List<float> vars_list = new List<float>();
+ FileStreamOptions options = new FileStreamOptions();
+ options.Access = FileAccess.Read;
+ options.Mode = FileMode.Open;
+ StreamReader srtReader = new StreamReader(mvnFilePath, options);
+ int i = 0;
+ while (!srtReader.EndOfStream)
+ {
+ string? strLine = srtReader.ReadLine();
+ if (!string.IsNullOrEmpty(strLine))
+ {
+ if (strLine.StartsWith("<AddShift>"))
+ {
+ i=1;
+ continue;
+ }
+ if (strLine.StartsWith("<Rescale>"))
+ {
+ i = 2;
+ continue;
+ }
+ if (strLine.StartsWith("<LearnRateCoef>") && i==1)
+ {
+ string[] add_shift_line = strLine.Substring(strLine.IndexOf("[") + 1, strLine.LastIndexOf("]") - strLine.IndexOf("[") - 1).Split(" ");
+ means_list = add_shift_line.Where(x => !string.IsNullOrEmpty(x)).Select(x => float.Parse(x.Trim())).ToList();
+ continue;
+ }
+ if (strLine.StartsWith("<LearnRateCoef>") && i==2)
+ {
+ string[] rescale_line = strLine.Substring(strLine.IndexOf("[") + 1, strLine.LastIndexOf("]") - strLine.IndexOf("[") - 1).Split(" ");
+ vars_list = rescale_line.Where(x => !string.IsNullOrEmpty(x)).Select(x => float.Parse(x.Trim())).ToList();
+ continue;
+ }
+ }
+ }
+ CmvnEntity cmvnEntity = new CmvnEntity();
+ cmvnEntity.Means = means_list;
+ cmvnEntity.Vars = vars_list;
+ return cmvnEntity;
+ }
+
+ }
+}
diff --git a/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs b/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs
new file mode 100644
index 0000000..785af32
--- /dev/null
+++ b/funasr/runtime/csharp/AliFsmnVadSharp/WindowDetector.cs
@@ -0,0 +1,156 @@
+锘縰sing System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace AliFsmnVadSharp
+{
+ public enum FrameState
+ {
+ kFrameStateInvalid = -1,
+ kFrameStateSpeech = 1,
+ kFrameStateSil = 0
+ }
+
+ /// <summary>
+ /// final voice/unvoice state per frame
+ /// </summary>
+ public enum AudioChangeState
+ {
+ kChangeStateSpeech2Speech = 0,
+ kChangeStateSpeech2Sil = 1,
+ kChangeStateSil2Sil = 2,
+ kChangeStateSil2Speech = 3,
+ kChangeStateNoBegin = 4,
+ kChangeStateInvalid = 5
+ }
+
+
+ internal class WindowDetector
+ {
+ private int _window_size_ms = 0; //window_size_ms;
+ private int _sil_to_speech_time = 0; //sil_to_speech_time;
+ private int _speech_to_sil_time = 0; //speech_to_sil_time;
+ private int _frame_size_ms = 0; //frame_size_ms;
+
+ private int _win_size_frame = 0;
+ private int _win_sum = 0;
+ private int[] _win_state = new int[0];// * _win_size_frame; // 鍒濆鍖栫獥
+
+ private int _cur_win_pos = 0;
+ private int _pre_frame_state = (int)FrameState.kFrameStateSil;
+ private int _cur_frame_state = (int)FrameState.kFrameStateSil;
+ private int _sil_to_speech_frmcnt_thres = 0; //int(sil_to_speech_time / frame_size_ms);
+ private int _speech_to_sil_frmcnt_thres = 0; //int(speech_to_sil_time / frame_size_ms);
+
+ private int _voice_last_frame_count = 0;
+ private int _noise_last_frame_count = 0;
+ private int _hydre_frame_count = 0;
+
+ public WindowDetector()
+ {
+
+ }
+
+ public WindowDetector(int window_size_ms, int sil_to_speech_time, int speech_to_sil_time, int frame_size_ms)
+ {
+ _window_size_ms = window_size_ms;
+ _sil_to_speech_time = sil_to_speech_time;
+ _speech_to_sil_time = speech_to_sil_time;
+ _frame_size_ms = frame_size_ms;
+
+ _win_size_frame = (int)(window_size_ms / frame_size_ms);
+ _win_sum = 0;
+ _win_state = new int[_win_size_frame];//[0] * _win_size_frame; // 鍒濆鍖栫獥
+
+ _cur_win_pos = 0;
+ _pre_frame_state = (int)FrameState.kFrameStateSil;
+ _cur_frame_state = (int)FrameState.kFrameStateSil;
+ _sil_to_speech_frmcnt_thres = (int)(sil_to_speech_time / frame_size_ms);
+ _speech_to_sil_frmcnt_thres = (int)(speech_to_sil_time / frame_size_ms);
+
+ _voice_last_frame_count = 0;
+ _noise_last_frame_count = 0;
+ _hydre_frame_count = 0;
+ }
+
+ public void Reset()
+ {
+ _cur_win_pos = 0;
+ _win_sum = 0;
+ _win_state = new int[_win_size_frame];
+ _pre_frame_state = (int)FrameState.kFrameStateSil;
+ _cur_frame_state = (int)FrameState.kFrameStateSil;
+ _voice_last_frame_count = 0;
+ _noise_last_frame_count = 0;
+ _hydre_frame_count = 0;
+ }
+
+
+ public int GetWinSize()
+ {
+ return _win_size_frame;
+ }
+
+ public AudioChangeState DetectOneFrame(FrameState frameState, int frame_count)
+ {
+
+
+ _cur_frame_state = (int)FrameState.kFrameStateSil;
+ if (frameState == FrameState.kFrameStateSpeech)
+ {
+ _cur_frame_state = 1;
+ }
+
+ else if (frameState == FrameState.kFrameStateSil)
+ {
+ _cur_frame_state = 0;
+ }
+
+ else
+ {
+ return AudioChangeState.kChangeStateInvalid;
+ }
+
+ _win_sum -= _win_state[_cur_win_pos];
+ _win_sum += _cur_frame_state;
+ _win_state[_cur_win_pos] = _cur_frame_state;
+ _cur_win_pos = (_cur_win_pos + 1) % _win_size_frame;
+
+ if (_pre_frame_state == (int)FrameState.kFrameStateSil && _win_sum >= _sil_to_speech_frmcnt_thres)
+ {
+ _pre_frame_state = (int)FrameState.kFrameStateSpeech;
+ return AudioChangeState.kChangeStateSil2Speech;
+ }
+
+
+ if (_pre_frame_state == (int)FrameState.kFrameStateSpeech && _win_sum <= _speech_to_sil_frmcnt_thres)
+ {
+ _pre_frame_state = (int)FrameState.kFrameStateSil;
+ return AudioChangeState.kChangeStateSpeech2Sil;
+ }
+
+
+ if (_pre_frame_state == (int)FrameState.kFrameStateSil)
+ {
+ return AudioChangeState.kChangeStateSil2Sil;
+ }
+
+ if (_pre_frame_state == (int)FrameState.kFrameStateSpeech)
+ {
+ return AudioChangeState.kChangeStateSpeech2Speech;
+ }
+
+ return AudioChangeState.kChangeStateInvalid;
+ }
+
+ private int FrameSizeMs()
+ {
+ return _frame_size_ms;
+ }
+
+
+
+ }
+}
diff --git a/funasr/runtime/csharp/README.md b/funasr/runtime/csharp/README.md
new file mode 100644
index 0000000..68175cd
--- /dev/null
+++ b/funasr/runtime/csharp/README.md
@@ -0,0 +1,59 @@
+# AliFsmnVadSharp
+##### 绠�浠嬶細
+椤圭洰涓娇鐢ㄧ殑VAD妯″瀷鏄樋閲屽反宸磋揪鎽╅櫌鎻愪緵鐨凢SMN-Monophone VAD妯″瀷銆�
+**椤圭洰鍩轰簬Net 6.0锛屼娇鐢–#缂栧啓锛岃皟鐢∕icrosoft.ML.OnnxRuntime瀵筼nnx妯″瀷杩涜瑙g爜锛屾敮鎸佽法骞冲彴缂栬瘧銆傞」鐩互搴撶殑褰㈠紡杩涜璋冪敤锛岄儴缃查潪甯告柟渚裤��**
+VAD鏁翠綋娴佺▼鐨剅tf鍦�0.008宸﹀彸銆�
+
+##### 鐢ㄩ�旓細
+16k涓枃閫氱敤VAD妯″瀷锛氬彲鐢ㄤ簬妫�娴嬮暱璇煶鐗囨涓湁鏁堣闊崇殑璧锋鏃堕棿鐐�.
+FSMN-Monophone VAD鏄揪鎽╅櫌璇煶鍥㈤槦鎻愬嚭鐨勯珮鏁堣闊崇鐐规娴嬫ā鍨嬶紝鐢ㄤ簬妫�娴嬭緭鍏ラ煶棰戜腑鏈夋晥璇煶鐨勮捣姝㈡椂闂寸偣淇℃伅锛屽苟灏嗘娴嬪嚭鏉ョ殑鏈夋晥闊抽鐗囨杈撳叆璇嗗埆寮曟搸杩涜璇嗗埆锛屽噺灏戞棤鏁堣闊冲甫鏉ョ殑璇嗗埆閿欒銆�
+
+##### VAD甯哥敤鍙傛暟璋冩暣璇存槑锛堝弬鑰冿細vad.yaml鏂囦欢锛夛細
+max_end_silence_time锛氬熬閮ㄨ繛缁娴嬪埌澶氶暱鏃堕棿闈欓煶杩涜灏剧偣鍒ゅ仠锛屽弬鏁拌寖鍥�500ms锝�6000ms锛岄粯璁ゅ��800ms(璇ュ�艰繃浣庡鏄撳嚭鐜拌闊虫彁鍓嶆埅鏂殑鎯呭喌)銆�
+speech_noise_thres锛歴peech鐨勫緱鍒嗗噺鍘籲oise鐨勫緱鍒嗗ぇ浜庢鍊煎垯鍒ゆ柇涓簊peech锛屽弬鏁拌寖鍥达細锛�-1,1锛�
+鍙栧�艰秺瓒嬩簬-1锛屽櫔闊宠璇垽瀹氫负璇煶鐨勬鐜囪秺澶э紝FA瓒婇珮
+鍙栧�艰秺瓒嬩簬+1锛岃闊宠璇垽瀹氫负鍣煶鐨勬鐜囪秺澶э紝Pmiss瓒婇珮
+閫氬父鎯呭喌涓嬶紝璇ュ�间細鏍规嵁褰撳墠妯″瀷鍦ㄩ暱璇煶娴嬭瘯闆嗕笂鐨勬晥鏋滃彇balance
+
+##### 妯″瀷鑾峰彇
+
+##### 璋冪敤鏂瑰紡锛�
+###### 1.娣诲姞椤圭洰寮曠敤
+using AliFsmnVadSharp;
+
+###### 2.鍒濆鍖栨ā鍨嬪拰閰嶇疆
+```csharp
+string applicationBase = AppDomain.CurrentDomain.BaseDirectory;
+string modelFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx";
+string configFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.yaml";
+string mvnFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.mvn";
+int batchSize = 2;//鎵归噺瑙g爜
+AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize);
+```
+###### 3.璋冪敤
+鏂规硶涓�(閫傜敤浜庡皬鏂囦欢)锛�
+```csharp
+SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples);
+```
+鏂规硶浜�(閫傜敤浜庡ぇ鏂囦欢)锛�
+```csharp
+SegmentEntity[] segments_duration = aliFsmnVad.GetSegmentsByStep(samples);
+```
+###### 4.杈撳嚭缁撴灉锛�
+```
+load model and init config elapsed_milliseconds:463.5390625
+vad infer result:
+[[70,2340][2620,6200][6480,23670][23950,26250][26780,28990][29950,31430][31750,37600][38210,46900][47310,49630][49910,56460][56740,59540][59820,70450]]
+elapsed_milliseconds:662.796875
+total_duration:70470.625
+rtf:0.009405292985552491
+```
+杈撳嚭鐨勬暟鎹紝渚嬪锛歔70,2340]锛屾槸浠ユ绉掍负鍗曚綅鐨剆egement鐨勮捣姝㈡椂闂达紝鍙互浠ユ涓轰緷鎹闊抽杩涜鍒嗙墖銆傚叾涓潤闊冲櫔闊抽儴鍒嗗凡琚幓闄ゃ��
+
+鍏朵粬璇存槑锛�
+娴嬭瘯鐢ㄤ緥锛欰liFsmnVadSharp.Examples銆�
+娴嬭瘯鐜锛歸indows11銆�
+娴嬭瘯鐢ㄤ緥涓璼amples鐨勮绠�,浣跨敤鐨勬槸NAudio搴撱��
+
+閫氳繃浠ヤ笅閾炬帴浜嗚В鏇村锛�
+https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary
diff --git a/funasr/runtime/docs/SDK_advanced_guide_cn.md b/funasr/runtime/docs/SDK_advanced_guide_cn.md
new file mode 100644
index 0000000..534375c
--- /dev/null
+++ b/funasr/runtime/docs/SDK_advanced_guide_cn.md
@@ -0,0 +1,261 @@
+# FunASR绂荤嚎鏂囦欢杞啓鏈嶅姟寮�鍙戞寚鍗�
+
+FunASR鎻愪緵鍙竴閿湰鍦版垨鑰呬簯绔湇鍔″櫒閮ㄧ讲鐨勪腑鏂囩绾挎枃浠惰浆鍐欐湇鍔★紝鍐呮牳涓篎unASR宸插紑婧恟untime-SDK銆侳unASR-runtime缁撳悎浜嗚揪鎽╅櫌璇煶瀹為獙瀹ゅ湪Modelscope绀惧尯寮�婧愮殑璇煶绔偣妫�娴�(VAD)銆丳araformer-large璇煶璇嗗埆(ASR)銆佹爣鐐规娴�(PUNC) 绛夌浉鍏宠兘鍔涳紝鍙互鍑嗙‘銆侀珮鏁堢殑瀵归煶棰戣繘琛岄珮骞跺彂杞啓銆�
+
+鏈枃妗d负FunASR绂荤嚎鏂囦欢杞啓鏈嶅姟寮�鍙戞寚鍗椼�傚鏋滄偍鎯冲揩閫熶綋楠岀绾挎枃浠惰浆鍐欐湇鍔★紝璇峰弬鑰僃unASR绂荤嚎鏂囦欢杞啓鏈嶅姟涓�閿儴缃茬ず渚嬶紙[鐐瑰嚮姝ゅ](./SDK_tutorial_cn.md)锛夈��
+
+## Docker瀹夎
+
+涓嬭堪姝ラ涓烘墜鍔ㄥ畨瑁卍ocker鍙奷ocker闀滃儚鐨勬楠わ紝濡傛偍docker闀滃儚宸插惎鍔紝鍙互蹇界暐鏈楠わ細
+
+### docker鐜瀹夎
+```shell
+# Ubuntu锛�
+curl -fsSL https://test.docker.com -o test-docker.sh
+sudo sh test-docker.sh
+# Debian锛�
+curl -fsSL https://get.docker.com -o get-docker.sh
+sudo sh get-docker.sh
+# CentOS锛�
+curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun
+# MacOS锛�
+brew install --cask --appdir=/Applications docker
+```
+
+瀹夎璇﹁锛歨ttps://alibaba-damo-academy.github.io/FunASR/en/installation/docker.html
+
+### docker鍚姩
+
+```shell
+sudo systemctl start docker
+```
+
+### 闀滃儚鎷夊彇鍙婂惎鍔�
+
+閫氳繃涓嬭堪鍛戒护鎷夊彇骞跺惎鍔‵unASR runtime-SDK鐨刣ocker闀滃儚锛�
+
+```shell
+sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1
+
+sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.0.1
+```
+
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+-p <瀹夸富鏈虹鍙�>:<鏄犲皠鍒癲ocker绔彛>
+濡傜ず渚嬶紝瀹夸富鏈�(ecs)绔彛10095鏄犲皠鍒癲ocker绔彛10095涓娿�傚墠鎻愭槸纭繚ecs瀹夊叏瑙勫垯鎵撳紑浜�10095绔彛銆�
+-v <瀹夸富鏈鸿矾寰�>:<鎸傝浇鑷砫ocker璺緞>
+濡傜ず渚嬶紝瀹夸富鏈鸿矾寰�/root鎸傝浇鑷砫ocker璺緞/workspace/models
+```
+
+
+## 鏈嶅姟绔惎鍔�
+
+docker鍚姩涔嬪悗锛屽惎鍔� funasr-wss-server鏈嶅姟绋嬪簭锛�
+
+funasr-wss-server鏀寔浠嶮odelscope涓嬭浇妯″瀷锛岃缃ā鍨嬩笅杞藉湴鍧�锛�--download-model-dir锛岄粯璁や负/workspace/models锛夊強model ID锛�--model-dir銆�--vad-dir銆�--punc-dir锛�,绀轰緥濡備笅锛�
+```shell
+cd /workspace/FunASR/funasr/runtime/websocket/build/bin
+./funasr-wss-server \
+ --download-model-dir /workspace/models \
+ --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+ --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+ --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
+ --decoder-thread-num 32 \
+ --io-thread-num 8 \
+ --port 10095 \
+ --certfile ../../../ssl_key/server.crt \
+ --keyfile ../../../ssl_key/server.key
+ ```
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--download-model-dir #妯″瀷涓嬭浇鍦板潃锛岄�氳繃璁剧疆model ID浠嶮odelscope涓嬭浇妯″瀷
+--model-dir # modelscope model ID
+--quantize # True涓洪噺鍖朅SR妯″瀷锛孎alse涓洪潪閲忓寲ASR妯″瀷锛岄粯璁ゆ槸True
+--vad-dir # modelscope model ID
+--vad-quant # True涓洪噺鍖朧AD妯″瀷锛孎alse涓洪潪閲忓寲VAD妯″瀷锛岄粯璁ゆ槸True
+--punc-dir # modelscope model ID
+--punc-quant # True涓洪噺鍖朠UNC妯″瀷锛孎alse涓洪潪閲忓寲PUNC妯″瀷锛岄粯璁ゆ槸True
+--port # 鏈嶅姟绔洃鍚殑绔彛鍙凤紝榛樿涓� 10095
+--decoder-thread-num # 鏈嶅姟绔惎鍔ㄧ殑鎺ㄧ悊绾跨▼鏁帮紝榛樿涓� 8
+--io-thread-num # 鏈嶅姟绔惎鍔ㄧ殑IO绾跨▼鏁帮紝榛樿涓� 1
+--certfile <string> # ssl鐨勮瘉涔︽枃浠讹紝榛樿涓猴細../../../ssl_key/server.crt
+--keyfile <string> # ssl鐨勫瘑閽ユ枃浠讹紝榛樿涓猴細../../../ssl_key/server.key
+```
+
+funasr-wss-server鍚屾椂涔熸敮鎸佷粠鏈湴璺緞鍔犺浇妯″瀷锛堟湰鍦版ā鍨嬭祫婧愬噯澶囪瑙乕妯″瀷璧勬簮鍑嗗](#anchor-1)锛夌ず渚嬪涓嬶細
+```shell
+cd /workspace/FunASR/funasr/runtime/websocket/build/bin
+./funasr-wss-server \
+ --model-dir /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+ --vad-dir /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+ --punc-dir /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
+ --decoder-thread-num 32 \
+ --io-thread-num 8 \
+ --port 10095 \
+ --certfile ../../../ssl_key/server.crt \
+ --keyfile ../../../ssl_key/server.key
+ ```
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--model-dir # ASR妯″瀷璺緞锛岄粯璁や负锛�/workspace/models/asr
+--quantize # True涓洪噺鍖朅SR妯″瀷锛孎alse涓洪潪閲忓寲ASR妯″瀷锛岄粯璁ゆ槸True
+--vad-dir # VAD妯″瀷璺緞锛岄粯璁や负锛�/workspace/models/vad
+--vad-quant # True涓洪噺鍖朧AD妯″瀷锛孎alse涓洪潪閲忓寲VAD妯″瀷锛岄粯璁ゆ槸True
+--punc-dir # PUNC妯″瀷璺緞锛岄粯璁や负锛�/workspace/models/punc
+--punc-quant # True涓洪噺鍖朠UNC妯″瀷锛孎alse涓洪潪閲忓寲PUNC妯″瀷锛岄粯璁ゆ槸True
+--port # 鏈嶅姟绔洃鍚殑绔彛鍙凤紝榛樿涓� 10095
+--decoder-thread-num # 鏈嶅姟绔惎鍔ㄧ殑鎺ㄧ悊绾跨▼鏁帮紝榛樿涓� 8
+--io-thread-num # 鏈嶅姟绔惎鍔ㄧ殑IO绾跨▼鏁帮紝榛樿涓� 1
+--certfile <string> # ssl鐨勮瘉涔︽枃浠讹紝榛樿涓猴細../../../ssl_key/server.crt
+--keyfile <string> # ssl鐨勫瘑閽ユ枃浠讹紝榛樿涓猴細../../../ssl_key/server.key
+```
+
+## <a id="anchor-1">妯″瀷璧勬簮鍑嗗</a>
+
+濡傛灉鎮ㄩ�夋嫨閫氳繃funasr-wss-server浠嶮odelscope涓嬭浇妯″瀷锛屽彲浠ヨ烦杩囨湰姝ラ銆�
+
+FunASR绂荤嚎鏂囦欢杞啓鏈嶅姟涓殑vad銆乤sr鍜宲unc妯″瀷璧勬簮鍧囨潵鑷狹odelscope锛屾ā鍨嬪湴鍧�璇﹁涓嬭〃锛�
+
+| 妯″瀷 | Modelscope閾炬帴 |
+|------|------------------------------------------------------------------------------------------------------------------|
+| VAD | https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary |
+| ASR | https://www.modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary |
+| PUNC | https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary |
+
+绂荤嚎鏂囦欢杞啓鏈嶅姟涓儴缃茬殑鏄噺鍖栧悗鐨凮NNX妯″瀷锛屼笅闈粙缁嶄笅濡備綍瀵煎嚭ONNX妯″瀷鍙婂叾閲忓寲锛氭偍鍙互閫夋嫨浠嶮odelscope瀵煎嚭ONNX妯″瀷銆佷粠鏈湴鏂囦欢瀵煎嚭ONNX妯″瀷鎴栬�呬粠finetune鍚庣殑璧勬簮瀵煎嚭妯″瀷锛�
+
+### 浠嶮odelscope瀵煎嚭ONNX妯″瀷
+
+浠嶮odelscope缃戠珯涓嬭浇瀵瑰簲model name鐨勬ā鍨嬶紝鐒跺悗瀵煎嚭閲忓寲鍚庣殑ONNX妯″瀷锛�
+
+```shell
+python -m funasr.export.export_model \
+--export-dir ./export \
+--type onnx \
+--quantize True \
+--model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
+--model-name damo/speech_fsmn_vad_zh-cn-16k-common-pytorch \
+--model-name damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch
+```
+
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--model-name Modelscope涓婄殑妯″瀷鍚嶇О锛屼緥濡俤amo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+--export-dir ONNX妯″瀷瀵煎嚭鍦板潃
+--type 妯″瀷绫诲瀷锛岀洰鍓嶆敮鎸� ONNX銆乼orch
+--quantize int8妯″瀷閲忓寲
+```
+
+### 浠庢湰鍦版枃浠跺鍑篛NNX妯″瀷
+
+璁剧疆model name涓烘ā鍨嬫湰鍦拌矾寰勶紝瀵煎嚭閲忓寲鍚庣殑ONNX妯″瀷锛�
+
+```shell
+python -m funasr.export.export_model --model-name /workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
+```
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--model-name 妯″瀷鏈湴璺緞锛屼緥濡�/workspace/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+--export-dir ONNX妯″瀷瀵煎嚭鍦板潃
+--type 妯″瀷绫诲瀷锛岀洰鍓嶆敮鎸� ONNX銆乼orch
+--quantize int8妯″瀷閲忓寲
+```
+
+### 浠巉inetune鍚庣殑璧勬簮瀵煎嚭妯″瀷
+
+鍋囧鎮ㄦ兂閮ㄧ讲finetune鍚庣殑妯″瀷锛屽彲浠ュ弬鑰冨涓嬫楠わ細
+
+灏嗘偍finetune鍚庨渶瑕侀儴缃茬殑妯″瀷锛堜緥濡�10epoch.pb锛夛紝閲嶅懡鍚嶄负model.pb锛屽苟灏嗗師modelscope涓ā鍨媘odel.pb鏇挎崲鎺夛紝鍋囧鏇挎崲鍚庣殑妯″瀷璺緞涓�/path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch锛岄�氳繃涓嬭堪鍛戒护鎶奻inetune鍚庣殑妯″瀷杞垚onnx妯″瀷锛�
+
+```shell
+python -m funasr.export.export_model --model-name /path/to/finetune/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
+```
+
+## 瀹㈡埛绔惎鍔�
+
+鍦ㄦ湇鍔″櫒涓婂畬鎴怓unASR绂荤嚎鏂囦欢杞啓鏈嶅姟閮ㄧ讲浠ュ悗锛屽彲浠ラ�氳繃濡備笅鐨勬楠ゆ潵娴嬭瘯鍜屼娇鐢ㄧ绾挎枃浠惰浆鍐欐湇鍔°�傜洰鍓岶unASR-bin鏀寔澶氱鏂瑰紡鍚姩瀹㈡埛绔紝濡備笅鏄熀浜巔ython-client銆乧++-client鐨勫懡浠よ瀹炰緥鍙婅嚜瀹氫箟瀹㈡埛绔疻ebsocket閫氫俊鍗忚锛�
+
+### python-client
+```shell
+python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
+```
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--host # 鏈嶅姟绔痠p鍦板潃锛屾湰鏈烘祴璇曞彲璁剧疆涓� 127.0.0.1
+--port # 鏈嶅姟绔洃鍚鍙e彿
+--audio_in # 闊抽杈撳叆锛岃緭鍏ュ彲浠ユ槸锛歸av璺緞 鎴栬�� wav.scp璺緞锛坘aldi鏍煎紡鐨剋av list锛寃av_id \t wav_path锛�
+--output_dir # 璇嗗埆缁撴灉杈撳嚭璺緞
+--ssl # 鏄惁浣跨敤SSL鍔犲瘑锛岄粯璁や娇鐢�
+--mode # offline妯″紡
+```
+
+### c++-client锛�
+```shell
+. /funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
+```
+鍛戒护鍙傛暟浠嬬粛锛�
+```text
+--server-ip # 鏈嶅姟绔痠p鍦板潃锛屾湰鏈烘祴璇曞彲璁剧疆涓� 127.0.0.1
+--port # 鏈嶅姟绔洃鍚鍙e彿
+--wav-path # 闊抽杈撳叆锛岃緭鍏ュ彲浠ユ槸锛歸av璺緞 鎴栬�� wav.scp璺緞锛坘aldi鏍煎紡鐨剋av list锛寃av_id \t wav_path锛�
+--thread-num # 瀹㈡埛绔嚎绋嬫暟
+--is-ssl # 鏄惁浣跨敤SSL鍔犲瘑锛岄粯璁や娇鐢�
+```
+
+### 鑷畾涔夊鎴风锛�
+
+濡傛灉鎮ㄦ兂瀹氫箟鑷繁鐨刢lient锛寃ebsocket閫氫俊鍗忚涓猴細
+
+```text
+# 棣栨閫氫俊
+{"mode": "offline", "wav_name": wav_name, "is_speaking": True}
+# 鍙戦�亀av鏁版嵁
+bytes鏁版嵁
+# 鍙戦�佺粨鏉熸爣蹇�
+{"is_speaking": False}
+```
+
+## 濡備綍瀹氬埗鏈嶅姟閮ㄧ讲
+
+FunASR-runtime鐨勪唬鐮佸凡寮�婧愶紝濡傛灉鏈嶅姟绔拰瀹㈡埛绔笉鑳藉緢濂界殑婊¤冻鎮ㄧ殑闇�姹傦紝鎮ㄥ彲浠ユ牴鎹嚜宸辩殑闇�姹傝繘琛岃繘涓�姝ョ殑寮�鍙戯細
+### c++ 瀹㈡埛绔細
+
+https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/websocket
+
+### python 瀹㈡埛绔細
+
+https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket
+### c++ 鏈嶅姟绔細
+
+#### VAD
+```c++
+// VAD妯″瀷鐨勪娇鐢ㄥ垎涓篎smnVadInit鍜孎smnVadInfer涓や釜姝ラ锛�
+FUNASR_HANDLE vad_hanlde=FsmnVadInit(model_path, thread_num);
+// 鍏朵腑锛歮odel_path 鍖呭惈"model-dir"銆�"quantize"锛宼hread_num涓簅nnx绾跨▼鏁帮紱
+FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
+// 鍏朵腑锛歷ad_hanlde涓篎unOfflineInit杩斿洖鍊硷紝wav_file涓洪煶棰戣矾寰勶紝sampling_rate涓洪噰鏍风巼(榛樿16k)
+```
+
+浣跨敤绀轰緥璇﹁锛歨ttps://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
+
+#### ASR
+```text
+// ASR妯″瀷鐨勪娇鐢ㄥ垎涓篎unOfflineInit鍜孎unOfflineInfer涓や釜姝ラ锛�
+FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
+// 鍏朵腑锛歮odel_path 鍖呭惈"model-dir"銆�"quantize"锛宼hread_num涓簅nnx绾跨▼鏁帮紱
+FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, 16000);
+// 鍏朵腑锛歛sr_hanlde涓篎unOfflineInit杩斿洖鍊硷紝wav_file涓洪煶棰戣矾寰勶紝sampling_rate涓洪噰鏍风巼(榛樿16k)
+```
+
+浣跨敤绀轰緥璇﹁锛歨ttps://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+
+#### PUNC
+```text
+// PUNC妯″瀷鐨勪娇鐢ㄥ垎涓篊TTransformerInit鍜孋TTransformerInfer涓や釜姝ラ锛�
+FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num);
+// 鍏朵腑锛歮odel_path 鍖呭惈"model-dir"銆�"quantize"锛宼hread_num涓簅nnx绾跨▼鏁帮紱
+FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
+// 鍏朵腑锛歱unc_hanlde涓篊TTransformerInit杩斿洖鍊硷紝txt_str涓烘枃鏈�
+```
+浣跨敤绀轰緥璇﹁锛歨ttps://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
diff --git a/funasr/runtime/docs/SDK_tutorial.md b/funasr/runtime/docs/SDK_tutorial.md
new file mode 100644
index 0000000..a4e46f1
--- /dev/null
+++ b/funasr/runtime/docs/SDK_tutorial.md
@@ -0,0 +1,336 @@
+# FunASR File Transcription Service Convenient Deployment Tutorial
+
+FunASR provides offline file transcription services that can be conveniently deployed on local or cloud servers. The core of the service is based on the open-source runtime-SDK of FunASR. It integrates various related capabilities, such as voice endpoint detection (VAD) and Paraformer-large speech recognition (ASR), as well as punctuation recovery (PUNC), which have been open-sourced by the speech laboratory of DAMO Academy on the Modelscope community. With these capabilities, the service can transcribe audio accurately and efficiently under high concurrency.
+
+## Installation and Start Service
+
+Environment Preparation and Configuration锛圼docs](./aliyun_server_tutorial.md)锛�
+
+### Downloading Tools and Deployment
+
+Run the following command to perform a one-click deployment of the FunASR runtime-SDK service. Follow the prompts to complete the deployment and running of the service. Currently, only Linux environments are supported, and for other environments, please refer to the Advanced SDK Development Guide. Due to network restrictions, the download of the funasr-runtime-deploy.sh one-click deployment tool may not proceed smoothly. If the tool has not been downloaded and entered into the one-click deployment tool after several seconds, please terminate it with Ctrl + C and run the following command again.
+
+```shell
+curl -O https://raw.githubusercontent.com/alibaba-damo-academy/FunASR-APP/main/TransAudio/funasr-runtime-deploy.sh; sudo bash funasr-runtime-deploy.sh install
+```
+
+#### Details of Configuration
+
+##### Choosing FunASR Docker Image
+
+We recommend selecting the "latest" tag to use our latest image, but you can also choose from our historical versions.
+
+```text
+[1/9]
+ Please choose the Docker image.
+ 1) registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+ 2) registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
+ Enter your choice: 1
+ You have chosen the Docker image: registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+```
+
+##### Choosing ASR/VAD/PUNC Models
+
+You can choose a model from ModelScope by name, or fill in the name of a model in ModelScope as <model_name>. The model will be automatically downloaded during Docker runtime. You can also select <model_path> to fill in the local model path on the host machine.
+
+```text
+[2/9]
+ Please input [Y/n] to confirm whether to automatically download model_id in ModelScope or use a local model.
+ [y] With the model in ModelScope, the model will be automatically downloaded to Docker(/workspace/models).
+ If you select both the local model and the model in ModelScope, select [y].
+ [n] Use the models on the localhost, the directory where the model is located will be mapped to Docker.
+ Setting confirmation[Y/n]:
+ You have chosen to use the model in ModelScope, please set the model ID in the next steps, and the model will be automatically downloaded in (/workspace/models) during the run.
+
+ Please enter the local path to download models, the corresponding path in Docker is /workspace/models.
+ Setting the local path to download models, default(/root/models):
+ The local path(/root/models) set will store models during the run.
+
+ [2.1/9]
+ Please select ASR model_id in ModelScope from the list below.
+ 1) damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The model dir in Docker is /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+
+ [2.2/9]
+ Please select VAD model_id in ModelScope from the list below.
+ 1) damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The model dir in Docker is /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+
+ [2.3/9]
+ Please select PUNC model_id in ModelScope from the list below.
+ 1) damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ The model dir in Docker is /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+```
+
+##### Enter the executable path of the FunASR service on the host machine
+
+Enter the host path of the executable of the FunASR service. It will be automatically mounted and run in Docker at runtime. If left blank, the default path in Docker will be set to /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server.
+
+```text
+[3/9]
+ Please enter the path to the excutor of the FunASR service on the localhost.
+ If not set, the default /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server in Docker is used.
+ Setting the path to the excutor of the FunASR service on the localhost:
+ Corresponding, the path of FunASR in Docker is /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server
+```
+
+##### Setting the port on the host machine for FunASR
+
+Setting the port on the host machine for Docker. The default port is 10095. Please ensure that this port is available.
+
+```text
+[4/9]
+ Please input the opened port in the host used for FunASR server.
+ Default: 10095
+ Setting the opened host port [1-65535]:
+ The port of the host is 10095
+ The port in Docker for FunASR server is 10095
+```
+
+
+##### Setting the number of inference threads for the FunASR service
+
+Setting the number of inference threads for the FunASR service. The default value is the number of cores on the host machine. The number of I/O threads for the service will also be automatically set to one-quarter of the number of inference threads.
+
+```text
+[5/9]
+ Please input thread number for FunASR decoder.
+ Default: 1
+ Setting the number of decoder thread:
+
+ The number of decoder threads is 1
+ The number of IO threads is 1
+```
+
+##### Displaying all set parameters for confirmation
+
+Displaying the parameters set in the previous 6 steps. Confirming will save all parameters to /var/funasr/config and start Docker. Otherwise, users will be prompted to reset the parameters.
+
+```text
+
+[6/9]
+ Show parameters of FunASR server setting and confirm to run ...
+
+ The current Docker image is : registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+ The model is downloaded or stored to this directory in local : /root/models
+ The model will be automatically downloaded to the directory : /workspace/models
+ The ASR model_id used : damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The ASR model directory corresponds to the directory in Docker : /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The VAD model_id used : damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The VAD model directory corresponds to the directory in Docker : /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The PUNC model_id used : damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ The PUNC model directory corresponds to the directory in Docker: /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+
+ The path in the docker of the FunASR service executor : /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server
+ Set the host port used for use by the FunASR service : 10095
+ Set the docker port used by the FunASR service : 10095
+ Set the number of threads used for decoding the FunASR service : 1
+ Set the number of threads used for IO the FunASR service : 1
+
+ Please input [Y/n] to confirm the parameters.
+ [y] Verify that these parameters are correct and that the service will run.
+ [n] The parameters set are incorrect, it will be rolled out, please rerun.
+ read confirmation[Y/n]:
+
+ Will run FunASR server later ...
+ Parameters are stored in the file /var/funasr/config
+```
+
+##### Checking the Docker service
+
+Checking if Docker service is installed on the host machine. If not installed, installing and starting Docker
+
+```text
+[7/9]
+ Start install docker for ubuntu
+ Get docker installer: curl -fsSL https://test.docker.com -o test-docker.sh
+ Get docker run: sudo sh test-docker.sh
+# Executing docker install script, commit: c2de0811708b6d9015ed1a2c80f02c9b70c8ce7b
++ sh -c apt-get update -qq >/dev/null
++ sh -c DEBIAN_FRONTEND=noninteractive apt-get install -y -qq apt-transport-https ca-certificates curl >/dev/null
++ sh -c install -m 0755 -d /etc/apt/keyrings
++ sh -c curl -fsSL "https://download.docker.com/linux/ubuntu/gpg" | gpg --dearmor --yes -o /etc/apt/keyrings/docker.gpg
++ sh -c chmod a+r /etc/apt/keyrings/docker.gpg
++ sh -c echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu focal test" > /etc/apt/sources.list.d/docker.list
++ sh -c apt-get update -qq >/dev/null
++ sh -c DEBIAN_FRONTEND=noninteractive apt-get install -y -qq docker-ce docker-ce-cli containerd.io docker-compose-plugin docker-ce-rootless-extras docker-buildx-plugin >/dev/null
++ sh -c docker version
+Client: Docker Engine - Community
+ Version: 24.0.2
+
+ ...
+ ...
+
+ Docker install success, start docker server.
+```
+
+##### Downloading the FunASR Docker image
+
+Downloading and updating the FunASR Docker image selected in step 1.1
+
+```text
+[8/9]
+ Pull docker image(registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest)...
+funasr-runtime-cpu-0.0.1: Pulling from funasr_repo/funasr
+7608715873ec: Pull complete
+3e1014c56f38: Pull complete
+
+ ...
+ ...
+```
+
+##### Starting the FunASR Docker
+
+Starting the FunASR Docker and waiting for the model selected in step 1.2 to finish downloading and start the FunASR service
+
+```text
+[9/9]
+ Construct command and run docker ...
+943d8f02b4e5011b71953a0f6c1c1b9bc5aff63e5a96e7406c83e80943b23474
+
+ Loading models:
+ [ASR ][Done ][==================================================][100%][1.10MB/s][v1.2.1]
+ [VAD ][Done ][==================================================][100%][7.26MB/s][v1.2.0]
+ [PUNC][Done ][==================================================][100%][ 474kB/s][v1.1.7]
+ The service has been started.
+ If you want to see an example of how to use the client, you can run sudo bash funasr-runtime-deploy.sh -c .
+```
+
+#### Starting the deployed FunASR service
+
+If the computer is restarted or Docker is closed after one-click deployment, the following command can be used to start the FunASR service directly with the settings from the last one-click deployment.
+
+```shell
+sudo bash funasr-runtime-deploy.sh start
+```
+
+#### Shutting down the FunASR service
+
+```shell
+sudo bash funasr-runtime-deploy.sh stop
+```
+
+#### Restarting the FunASR service
+
+Restarting the FunASR service with the settings from the last one-click deployment
+
+```shell
+sudo bash funasr-runtime-deploy.sh restart
+```
+
+#### Replacing the model and restarting the FunASR service
+
+Replacing the currently used model and restarting the FunASR service. The model must be an ASR/VAD/PUNC model from ModelScope.
+
+```shell
+sudo bash scripts/funasr-runtime-deploy.sh update model <model ID in ModelScope>
+
+e.g
+sudo bash scripts/funasr-runtime-deploy.sh update model damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+```
+
+### How to test and use the offline file transcription service
+
+After completing the FunASR service deployment on the server, you can test and use the offline file transcription service by following these steps. Currently, command line running is supported for Python, C++, and Java client versions, as well as an HTML web page version that can be directly experienced in the browser. For more client language support, please refer to the "FunASR Advanced Development Guide" documentation.
+After the funasr-runtime-deploy.sh script finishes running, you can use the following command to automatically download the test samples to the funasr_samples directory in the current directory and run the program with the set parameters in an interactive manner:
+
+```shell
+sudo bash funasr-runtime-deploy.sh client
+```
+
+You can choose from the provided Python and Linux C++ sample programs. Taking the Python sample as an example:
+
+```text
+Will download sample tools for the client to show how speech recognition works.
+ Please select the client you want to run.
+ 1) Python
+ 2) Linux_Cpp
+ Enter your choice: 1
+
+ Please enter the IP of server, default(127.0.0.1):
+ Please enter the port of server, default(10095):
+ Please enter the audio path, default(/root/funasr_samples/audio/asr_example.wav):
+
+ Run pip3 install click>=8.0.4
+Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/
+Requirement already satisfied: click>=8.0.4 in /usr/local/lib/python3.8/dist-packages (8.1.3)
+
+ Run pip3 install -r /root/funasr_samples/python/requirements_client.txt
+Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/
+Requirement already satisfied: websockets in /usr/local/lib/python3.8/dist-packages (from -r /root/funasr_samples/python/requirements_client.txt (line 1)) (11.0.3)
+
+ Run python3 /root/funasr_samples/python/wss_client_asr.py --host 127.0.0.1 --port 10095 --mode offline --audio_in /root/funasr_samples/audio/asr_example.wav --send_without_sleep --output_dir ./funasr_samples/python
+
+ ...
+ ...
+
+ pid0_0: 娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨嬨��
+Exception: sent 1000 (OK); then received 1000 (OK)
+end
+
+ If failed, you can try (python3 /root/funasr_samples/python/wss_client_asr.py --host 127.0.0.1 --port 10095 --mode offline --audio_in /root/funasr_samples/audio/asr_example.wav --send_without_sleep --output_dir ./funasr_samples/python) in your Shell.
+
+```
+
+#### python-client
+
+If you want to directly run the client for testing, you can refer to the following simple instructions, taking the Python version as an example:
+```shell
+python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav" --send_without_sleep --output_dir "./results"
+```
+
+Command parameter instructions:
+
+```text
+--host: The IP address of the machine where the FunASR runtime-SDK service is deployed. The default is the local IP address (127.0.0.1). If the client and service are not on the same server, the IP address should be changed to that of the deployment machine.
+--port 10095: The deployment port number.
+--mode offline: Indicates offline file transcription.
+--audio_in: The audio file(s) to be transcribed, which can be a file path or a file list (wav.scp).
+--output_dir: The path to save the recognition results.
+```
+
+#### cpp-client
+
+```shell
+export LD_LIBRARY_PATH=/root/funasr_samples/cpp/libs:$LD_LIBRARY_PATH
+/root/funasr_samples/cpp/funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path /root/funasr_samples/audio/asr_example.wav
+```
+
+Command parameter instructions:
+
+```text
+--server-ip: The IP address of the machine where the FunASR runtime-SDK service is deployed. The default is the local IP address (127.0.0.1). If the client and service are not on the same server, the IP address should be changed to that of the deployment machine.
+--port 10095: The deployment port number.
+--wav-path: The audio file(s) to be transcribed, which can be a file path.
+```
+
+### Video demo
+
+[demo]()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/funasr/runtime/docs/SDK_tutorial_cn.md b/funasr/runtime/docs/SDK_tutorial_cn.md
new file mode 100644
index 0000000..a465501
--- /dev/null
+++ b/funasr/runtime/docs/SDK_tutorial_cn.md
@@ -0,0 +1,327 @@
+# FunASR绂荤嚎鏂囦欢杞啓鏈嶅姟渚挎嵎閮ㄧ讲鏁欑▼
+
+FunASR鎻愪緵鍙究鎹锋湰鍦版垨鑰呬簯绔湇鍔″櫒閮ㄧ讲鐨勭绾挎枃浠惰浆鍐欐湇鍔★紝鍐呮牳涓篎unASR宸插紑婧恟untime-SDK銆傞泦鎴愪簡杈炬懇闄㈣闊冲疄楠屽鍦∕odelscope绀惧尯寮�婧愮殑璇煶绔偣妫�娴�(VAD)銆丳araformer-large璇煶璇嗗埆(ASR)銆佹爣鐐规仮澶�(PUNC) 绛夌浉鍏宠兘鍔涳紝鍙互鍑嗙‘銆侀珮鏁堢殑瀵归煶棰戣繘琛岄珮骞跺彂杞啓銆�
+
+## 鐜瀹夎涓庡惎鍔ㄦ湇鍔�
+
+鏈嶅姟鍣ㄩ厤缃笌鐢宠锛堝厤璐硅瘯鐢�1锝�3涓湀锛夛紙[鐐瑰嚮姝ゅ](./aliyun_server_tutorial.md)锛�
+### 鑾峰緱鑴氭湰宸ュ叿骞朵竴閿儴缃�
+
+閫氳繃浠ヤ笅鍛戒护杩愯涓�閿儴缃叉湇鍔★紝鎸夌収鎻愮ず閫愭瀹屾垚FunASR runtime-SDK鏈嶅姟鐨勯儴缃插拰杩愯銆傜洰鍓嶆殏鏃朵粎鏀寔Linux鐜锛屽叾浠栫幆澧冨弬鑰冩枃妗楂橀樁寮�鍙戞寚鍗梋(./SDK_advanced_guide_cn.md)銆�
+鍙楅檺浜庣綉缁滐紝funasr-runtime-deploy.sh涓�閿儴缃插伐鍏风殑涓嬭浇鍙兘涓嶉『鍒╋紝閬囧埌鏁扮杩樻湭涓嬭浇杩涘叆涓�閿儴缃插伐鍏风殑鎯呭喌锛岃Ctrl + C 缁堟鍚庡啀娆¤繍琛屼互涓嬪懡浠ゃ��
+
+```shell
+curl -O https://raw.githubusercontent.com/alibaba-damo-academy/FunASR-APP/main/TransAudio/funasr-runtime-deploy.sh; sudo bash funasr-runtime-deploy.sh install
+```
+
+#### 鍚姩杩囩▼閰嶇疆璇﹁В
+
+##### 閫夋嫨FunASR Docker闀滃儚
+鎺ㄨ崘閫夋嫨latest浣跨敤鎴戜滑鐨勬渶鏂伴暅鍍忥紝涔熷彲閫夋嫨鍘嗗彶鐗堟湰銆�
+```text
+[1/9]
+ Please choose the Docker image.
+ 1) registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+ 2) registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
+ Enter your choice: 1
+ You have chosen the Docker image: registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+```
+
+##### 閫夋嫨ASR/VAD/PUNC妯″瀷
+
+浣犲彲浠ラ�夋嫨ModelScope涓殑妯″瀷锛屼篃鍙互閫�<model_name>鑷濉叆ModelScope涓殑妯″瀷鍚嶏紝灏嗕細鍦―ocker杩愯鏃惰嚜鍔ㄤ笅杞姐�傚悓鏃朵篃鍙互閫夋嫨<model_path>濉叆瀹夸富鏈轰腑鐨勬湰鍦版ā鍨嬭矾寰勩��
+
+```text
+[2/9]
+ Please input [Y/n] to confirm whether to automatically download model_id in ModelScope or use a local model.
+ [y] With the model in ModelScope, the model will be automatically downloaded to Docker(/workspace/models).
+ If you select both the local model and the model in ModelScope, select [y].
+ [n] Use the models on the localhost, the directory where the model is located will be mapped to Docker.
+ Setting confirmation[Y/n]:
+ You have chosen to use the model in ModelScope, please set the model ID in the next steps, and the model will be automatically downloaded in (/workspace/models) during the run.
+
+ Please enter the local path to download models, the corresponding path in Docker is /workspace/models.
+ Setting the local path to download models, default(/root/models):
+ The local path(/root/models) set will store models during the run.
+
+ [2.1/9]
+ Please select ASR model_id in ModelScope from the list below.
+ 1) damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The model dir in Docker is /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+
+ [2.2/9]
+ Please select VAD model_id in ModelScope from the list below.
+ 1) damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The model dir in Docker is /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+
+ [2.3/9]
+ Please select PUNC model_id in ModelScope from the list below.
+ 1) damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ 2) model_name
+ 3) model_path
+ Enter your choice: 1
+ The model ID is damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ The model dir in Docker is /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+```
+
+##### 杈撳叆瀹夸富鏈轰腑FunASR鏈嶅姟鍙墽琛岀▼搴忚矾寰�
+
+杈撳叆FunASR鏈嶅姟鍙墽琛岀▼搴忕殑瀹夸富鏈鸿矾寰勶紝Docker杩愯鏃跺皢鑷姩鎸傝浇鍒癉ocker涓繍琛屻�傞粯璁や笉杈撳叆鐨勬儏鍐典笅灏嗘寚瀹欴ocker涓粯璁ょ殑/workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server銆�
+
+```text
+[3/9]
+ Please enter the path to the excutor of the FunASR service on the localhost.
+ If not set, the default /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server in Docker is used.
+ Setting the path to the excutor of the FunASR service on the localhost:
+ Corresponding, the path of FunASR in Docker is /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server
+```
+
+##### 璁剧疆瀹夸富鏈烘彁渚涚粰FunASR鐨勭鍙�
+璁剧疆鎻愪緵缁橠ocker鐨勫涓绘満绔彛锛岄粯璁や负10095銆傝淇濊瘉姝ょ鍙e彲鐢ㄣ��
+```text
+[4/9]
+ Please input the opened port in the host used for FunASR server.
+ Default: 10095
+ Setting the opened host port [1-65535]:
+ The port of the host is 10095
+ The port in Docker for FunASR server is 10095
+```
+
+
+##### 璁剧疆FunASR鏈嶅姟鐨勬帹鐞嗙嚎绋嬫暟
+璁剧疆FunASR鏈嶅姟鐨勬帹鐞嗙嚎绋嬫暟锛岄粯璁や负瀹夸富鏈烘牳鏁帮紝鍚屾椂鑷姩璁剧疆鏈嶅姟鐨処O绾跨▼鏁帮紝涓烘帹鐞嗙嚎绋嬫暟鐨勫洓鍒嗕箣涓�銆�
+```text
+[5/9]
+ Please input thread number for FunASR decoder.
+ Default: 1
+ Setting the number of decoder thread:
+
+ The number of decoder threads is 1
+ The number of IO threads is 1
+```
+
+##### 鎵�鏈夎缃弬鏁板睍绀哄強纭
+
+灞曠ず鍓嶉潰6姝ヨ缃殑鍙傛暟锛岀‘璁ゅ垯灏嗘墍鏈夊弬鏁板瓨鍌ㄥ埌/var/funasr/config锛屽苟寮�濮嬪惎鍔―ocker锛屽惁鍒欐彁绀虹敤鎴疯繘琛岄噸鏂拌缃��
+
+```text
+
+[6/9]
+ Show parameters of FunASR server setting and confirm to run ...
+
+ The current Docker image is : registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest
+ The model is downloaded or stored to this directory in local : /root/models
+ The model will be automatically downloaded to the directory : /workspace/models
+ The ASR model_id used : damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The ASR model directory corresponds to the directory in Docker : /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+ The VAD model_id used : damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The VAD model directory corresponds to the directory in Docker : /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx
+ The PUNC model_id used : damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+ The PUNC model directory corresponds to the directory in Docker: /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+
+ The path in the docker of the FunASR service executor : /workspace/FunASR/funasr/runtime/websocket/build/bin/funasr-wss-server
+ Set the host port used for use by the FunASR service : 10095
+ Set the docker port used by the FunASR service : 10095
+ Set the number of threads used for decoding the FunASR service : 1
+ Set the number of threads used for IO the FunASR service : 1
+
+ Please input [Y/n] to confirm the parameters.
+ [y] Verify that these parameters are correct and that the service will run.
+ [n] The parameters set are incorrect, it will be rolled out, please rerun.
+ read confirmation[Y/n]:
+
+ Will run FunASR server later ...
+ Parameters are stored in the file /var/funasr/config
+```
+
+##### 妫�鏌ocker鏈嶅姟
+
+妫�鏌ュ綋鍓嶅涓绘満鏄惁瀹夎浜咲ocker鏈嶅姟锛岃嫢鏈畨瑁咃紝鍒欏畨瑁匘ocker骞跺惎鍔ㄣ��
+
+```text
+[7/9]
+ Start install docker for ubuntu
+ Get docker installer: curl -fsSL https://test.docker.com -o test-docker.sh
+ Get docker run: sudo sh test-docker.sh
+# Executing docker install script, commit: c2de0811708b6d9015ed1a2c80f02c9b70c8ce7b
++ sh -c apt-get update -qq >/dev/null
++ sh -c DEBIAN_FRONTEND=noninteractive apt-get install -y -qq apt-transport-https ca-certificates curl >/dev/null
++ sh -c install -m 0755 -d /etc/apt/keyrings
++ sh -c curl -fsSL "https://download.docker.com/linux/ubuntu/gpg" | gpg --dearmor --yes -o /etc/apt/keyrings/docker.gpg
++ sh -c chmod a+r /etc/apt/keyrings/docker.gpg
++ sh -c echo "deb [arch=amd64 signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu focal test" > /etc/apt/sources.list.d/docker.list
++ sh -c apt-get update -qq >/dev/null
++ sh -c DEBIAN_FRONTEND=noninteractive apt-get install -y -qq docker-ce docker-ce-cli containerd.io docker-compose-plugin docker-ce-rootless-extras docker-buildx-plugin >/dev/null
++ sh -c docker version
+Client: Docker Engine - Community
+ Version: 24.0.2
+
+ ...
+ ...
+
+ Docker install success, start docker server.
+```
+
+##### 涓嬭浇FunASR Docker闀滃儚
+
+涓嬭浇骞舵洿鏂皊tep1.1涓�夋嫨鐨凢unASR Docker闀滃儚銆�
+
+```text
+[8/9]
+ Pull docker image(registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-latest)...
+funasr-runtime-cpu-0.0.1: Pulling from funasr_repo/funasr
+7608715873ec: Pull complete
+3e1014c56f38: Pull complete
+
+ ...
+ ...
+```
+
+##### 鍚姩FunASR Docker
+
+鍚姩FunASR Docker锛岀瓑寰卻tep1.2閫夋嫨鐨勬ā鍨嬩笅杞藉畬鎴愬苟鍚姩FunASR鏈嶅姟銆�
+
+```text
+[9/9]
+ Construct command and run docker ...
+943d8f02b4e5011b71953a0f6c1c1b9bc5aff63e5a96e7406c83e80943b23474
+
+ Loading models:
+ [ASR ][Done ][==================================================][100%][1.10MB/s][v1.2.1]
+ [VAD ][Done ][==================================================][100%][7.26MB/s][v1.2.0]
+ [PUNC][Done ][==================================================][100%][ 474kB/s][v1.1.7]
+ The service has been started.
+ If you want to see an example of how to use the client, you can run sudo bash funasr-runtime-deploy.sh -c .
+```
+
+#### 鍚姩宸茬粡閮ㄧ讲杩囩殑FunASR鏈嶅姟
+涓�閿儴缃插悗鑻ュ嚭鐜伴噸鍚數鑴戠瓑鍏抽棴Docker鐨勫姩浣滐紝鍙�氳繃濡備笅鍛戒护鐩存帴鍚姩FunASR鏈嶅姟锛屽惎鍔ㄩ厤缃负涓婃涓�閿儴缃茬殑璁剧疆銆�
+
+```shell
+sudo bash funasr-runtime-deploy.sh start
+```
+
+#### 鍏抽棴FunASR鏈嶅姟
+
+```shell
+sudo bash funasr-runtime-deploy.sh stop
+```
+
+#### 閲嶅惎FunASR鏈嶅姟
+
+鏍规嵁涓婃涓�閿儴缃茬殑璁剧疆閲嶅惎鍚姩FunASR鏈嶅姟銆�
+```shell
+sudo bash funasr-runtime-deploy.sh restart
+```
+
+#### 鏇挎崲妯″瀷骞堕噸鍚疐unASR鏈嶅姟
+
+鏇挎崲姝e湪浣跨敤鐨勬ā鍨嬶紝骞堕噸鏂板惎鍔‵unASR鏈嶅姟銆傛ā鍨嬮渶涓篗odelScope涓殑ASR/VAD/PUNC妯″瀷锛屾垨鑰呬粠ModelScope涓ā鍨媐inetune鍚庣殑妯″瀷銆�
+
+```shell
+sudo bash funasr-runtime-deploy.sh update model <model ID>
+
+e.g
+sudo bash funasr-runtime-deploy.sh update model damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+```
+
+### 娴嬭瘯涓庝娇鐢ㄧ绾挎枃浠惰浆鍐欐湇鍔�
+
+鍦ㄦ湇鍔″櫒涓婂畬鎴怓unASR鏈嶅姟閮ㄧ讲浠ュ悗锛屽彲浠ラ�氳繃濡備笅鐨勬楠ゆ潵娴嬭瘯鍜屼娇鐢ㄧ绾挎枃浠惰浆鍐欐湇鍔°�傜洰鍓嶅垎鍒敮鎸丳ython銆丆++銆丣ava鐗堟湰client鐨勭殑鍛戒护琛岃繍琛岋紝浠ュ強鍙湪娴忚鍣ㄥ彲鐩存帴浣撻獙鐨刪tml缃戦〉鐗堟湰锛屾洿澶氳瑷�client鏀寔鍙傝�冩枃妗c�怓unASR楂橀樁寮�鍙戞寚鍗椼�戙��
+funasr-runtime-deploy.sh杩愯缁撴潫鍚庯紝鍙�氳繃鍛戒护浠ヤ氦浜掔殑褰㈠紡鑷姩涓嬭浇娴嬭瘯鏍蜂緥samples鍒板綋鍓嶇洰褰曠殑funasr_samples涓紝骞惰缃弬鏁拌繍琛岋細
+
+```shell
+sudo bash funasr-runtime-deploy.sh client
+```
+
+鍙�夋嫨鎻愪緵鐨凱ython鍜孡inux C++鑼冧緥绋嬪簭锛屼互Python鑼冧緥涓轰緥锛�
+
+```text
+Will download sample tools for the client to show how speech recognition works.
+ Please select the client you want to run.
+ 1) Python
+ 2) Linux_Cpp
+ Enter your choice: 1
+
+ Please enter the IP of server, default(127.0.0.1):
+ Please enter the port of server, default(10095):
+ Please enter the audio path, default(/root/funasr_samples/audio/asr_example.wav):
+
+ Run pip3 install click>=8.0.4
+Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/
+Requirement already satisfied: click>=8.0.4 in /usr/local/lib/python3.8/dist-packages (8.1.3)
+
+ Run pip3 install -r /root/funasr_samples/python/requirements_client.txt
+Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/
+Requirement already satisfied: websockets in /usr/local/lib/python3.8/dist-packages (from -r /root/funasr_samples/python/requirements_client.txt (line 1)) (11.0.3)
+
+ Run python3 /root/funasr_samples/python/wss_client_asr.py --host 127.0.0.1 --port 10095 --mode offline --audio_in /root/funasr_samples/audio/asr_example.wav --send_without_sleep --output_dir ./funasr_samples/python
+
+ ...
+ ...
+
+ pid0_0: 娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨嬨��
+Exception: sent 1000 (OK); then received 1000 (OK)
+end
+
+ If failed, you can try (python3 /root/funasr_samples/python/wss_client_asr.py --host 127.0.0.1 --port 10095 --mode offline --audio_in /root/funasr_samples/audio/asr_example.wav --send_without_sleep --output_dir ./funasr_samples/python) in your Shell.
+
+```
+
+#### python-client
+鑻ユ兂鐩存帴杩愯client杩涜娴嬭瘯锛屽彲鍙傝�冨涓嬬畝鏄撹鏄庯紝浠ython鐗堟湰涓轰緥锛�
+
+```shell
+python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav" --send_without_sleep --output_dir "./results"
+```
+
+鍛戒护鍙傛暟璇存槑锛�
+```text
+--host 涓篎unASR runtime-SDK鏈嶅姟閮ㄧ讲鏈哄櫒ip锛岄粯璁や负鏈満ip锛�127.0.0.1锛夛紝濡傛灉client涓庢湇鍔′笉鍦ㄥ悓涓�鍙版湇鍔″櫒锛岄渶瑕佹敼涓洪儴缃叉満鍣╥p
+--port 10095 閮ㄧ讲绔彛鍙�
+--mode offline琛ㄧず绂荤嚎鏂囦欢杞啓
+--audio_in 闇�瑕佽繘琛岃浆鍐欑殑闊抽鏂囦欢锛屾敮鎸佹枃浠惰矾寰勶紝鏂囦欢鍒楄〃wav.scp
+--output_dir 璇嗗埆缁撴灉淇濆瓨璺緞
+```
+
+#### cpp-client
+
+```shell
+export LD_LIBRARY_PATH=/root/funasr_samples/cpp/libs:$LD_LIBRARY_PATH
+/root/funasr_samples/cpp/funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path /root/funasr_samples/audio/asr_example.wav
+```
+
+鍛戒护鍙傛暟璇存槑锛�
+
+```text
+--server-ip 涓篎unASR runtime-SDK鏈嶅姟閮ㄧ讲鏈哄櫒ip锛岄粯璁や负鏈満ip锛�127.0.0.1锛夛紝濡傛灉client涓庢湇鍔′笉鍦ㄥ悓涓�鍙版湇鍔″櫒锛岄渶瑕佹敼涓洪儴缃叉満鍣╥p
+--port 10095 閮ㄧ讲绔彛鍙�
+--wav-path 闇�瑕佽繘琛岃浆鍐欑殑闊抽鏂囦欢锛屾敮鎸佹枃浠惰矾寰�
+```
+
+### 瑙嗛demo
+
+[鐐瑰嚮姝ゅ]()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/funasr/runtime/docs/aliyun_server_tutorial.md b/funasr/runtime/docs/aliyun_server_tutorial.md
new file mode 100644
index 0000000..e68d8bf
--- /dev/null
+++ b/funasr/runtime/docs/aliyun_server_tutorial.md
@@ -0,0 +1,74 @@
+# 浜戞湇鍔″櫒鐢宠鏁欑▼
+
+鎴戜滑浠ラ樋閲屼簯锛圼鐐规閾炬帴](https://www.aliyun.com/)锛変负渚嬶紝婕旂ず濡備綍鐢宠浜戞湇鍔″櫒
+
+## 鏈嶅姟鍣ㄩ厤缃�
+
+鐢ㄦ埛鍙互鏍规嵁鑷繁鐨勪笟鍔¢渶姹傦紝閫夋嫨鍚堥�傜殑鏈嶅姟鍣ㄩ厤缃紝鎺ㄨ崘閰嶇疆涓猴細
+- 閰嶇疆涓�锛堥珮閰嶏級锛歑86鏋舵瀯锛�32/64鏍�8369CPU锛屽唴瀛�8G浠ヤ笂锛�
+- 閰嶇疆浜岋細X86鏋舵瀯锛�32/64鏍�8163CPU锛屽唴瀛�8G浠ヤ笂锛�
+
+璇︾粏鎬ц兘娴嬭瘯鎶ュ憡锛歔鐐规閾炬帴](./benchmark_onnx_cpp.md)
+
+鎴戜滑浠ュ厤璐硅瘯鐢紙1锝�3涓湀锛変负渚嬶紝婕旂ず濡備綍鐢宠鏈嶅姟鍣ㄦ祦绋嬶紝鍥炬枃姝ラ濡備笅锛�
+
+### 鐧婚檰涓汉璐﹀彿
+鎵撳紑闃块噷浜戝畼缃慬鐐规閾炬帴](https://www.aliyun.com/)锛屾敞鍐屽苟鐧婚檰涓汉璐﹀彿锛屽涓嬪浘鏍囧彿1鎵�绀�
+
+<img src="images/aliyun1.png" width="900"/>
+
+### 鍏嶈垂璇曠敤
+
+鐐瑰嚮濡備笂鍥炬墍浠ユ爣鍙�2锛屽嚭鐜板涓嬬晫闈�
+
+<img src="images/aliyun2.png" width="900"/>
+
+鍐嶇偣鍑绘爣鍙�3锛屽嚭鐜板涓嬬晫闈�
+
+<img src="images/aliyun3.png" width="900"/>
+
+### 鐢宠ECS瀹炰緥
+
+涓汉璐﹀彿鍙互鍏嶈垂璇曠敤1鏍�2GB鍐呭瓨锛屾瘡鏈�750灏忔椂锛屼紒涓氳璇佸悗锛屽彲浠ュ厤璐硅瘯鐢�2鏍�8GB鍐呭瓨 3涓湀锛屾牴鎹处鍙锋儏鍐碉紝鐐瑰嚮涓婂浘涓爣鍙�4锛屽嚭鐜板涓嬬晫闈細
+
+<img src="images/aliyun4.png" width="900"/>
+
+渚濇鎸夌収涓婂浘鏍囧彿5銆�6銆�7閫夋嫨鍚庯紝鐐瑰嚮绔嬪嵆璇曠敤锛屽嚭鐜板涓嬬晫闈�
+
+<img src="images/aliyun5.png" width="900"/>
+
+### 寮�鏀炬湇鍔$鍙�
+
+鐐瑰嚮瀹夊叏缁勶紙鏍囧彿9锛夛紝鍑虹幇濡備笅鐣岄潰
+
+<img src="images/aliyun6.png" width="900"/>
+
+鍐嶇偣鍑绘爣鍙�10锛屽嚭鐜板涓嬬晫闈�
+
+<img src="images/aliyun7.png" width="900"/>
+
+鐐瑰嚮鎵嬪姩娣诲姞锛堟爣鍙�11锛夛紝鍒嗗埆鎸夌収鏍囧彿12銆�13濉叆鍐呭锛屽悗鐐瑰嚮淇濆瓨锛堟爣鍙�14锛夛紝鍐嶇偣鍑诲疄渚嬶紙鏍囧彿15锛夛紝鍑虹幇濡備笅鐣岄潰
+
+<img src="images/aliyun8.png" width="900"/>
+
+### 鍚姩ECS绀轰緥
+
+鐐瑰嚮绀轰緥鍚嶇О锛堟爣鍙�16锛夛紝鍑虹幇濡備笅椤甸潰
+
+<img src="images/aliyun9.png" width="900"/>
+
+鐐瑰嚮杩滅▼鍚姩锛堟爣鍙�17锛夛紝鍑虹幇椤甸潰鍚庯紝鐐瑰嚮绔嬪嵆鐧婚檰锛屽嚭鐜板涓嬬晫闈�
+
+<img src="images/aliyun10.png" width="900"/>
+
+棣栨鐧婚檰闇�瑕佺偣鍑婚噸缃瘑鐮侊紙涓婂浘涓豢鑹茬澶达級锛岃缃ソ瀵嗙爜鍚庯紝杈撳叆瀵嗙爜锛堟爣鍙�18锛夛紝鐐瑰嚮纭锛堟爣鍙�19锛�
+
+<img src="images/aliyun11.png" width="900"/>
+
+棣栨鐧婚檰浼氶亣鍒颁笂鍥炬墍绀洪棶棰橈紝鐐瑰嚮鏍囧彿20锛屾牴鎹枃妗f搷浣滃悗锛岄噸鏂扮櫥闄嗭紝鐧婚檰鎴愬姛鍚庡嚭鐜板涓嬮〉闈�
+
+<img src="images/aliyun12.png" width="900"/>
+
+涓婂浘琛ㄧず宸茬粡鎴愬姛鐢宠浜嗕簯鏈嶅姟鍣紝鍚庣画鍙互鏍规嵁FunASR runtime-SDK閮ㄧ讲鏂囨。杩涜涓�閿儴缃诧紙[鐐瑰嚮姝ゅ]()锛�
+
+
diff --git a/funasr/runtime/python/benchmark_libtorch.md b/funasr/runtime/docs/benchmark_libtorch.md
similarity index 100%
rename from funasr/runtime/python/benchmark_libtorch.md
rename to funasr/runtime/docs/benchmark_libtorch.md
diff --git a/funasr/runtime/python/benchmark_onnx.md b/funasr/runtime/docs/benchmark_onnx.md
similarity index 100%
rename from funasr/runtime/python/benchmark_onnx.md
rename to funasr/runtime/docs/benchmark_onnx.md
diff --git a/funasr/runtime/python/benchmark_onnx_cpp.md b/funasr/runtime/docs/benchmark_onnx_cpp.md
similarity index 100%
rename from funasr/runtime/python/benchmark_onnx_cpp.md
rename to funasr/runtime/docs/benchmark_onnx_cpp.md
diff --git a/funasr/runtime/docs/images/aliyun1.png b/funasr/runtime/docs/images/aliyun1.png
new file mode 100644
index 0000000..f9a29d7
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun1.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun10.png b/funasr/runtime/docs/images/aliyun10.png
new file mode 100644
index 0000000..899a9f0
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun10.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun11.png b/funasr/runtime/docs/images/aliyun11.png
new file mode 100644
index 0000000..2023365
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun11.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun12.png b/funasr/runtime/docs/images/aliyun12.png
new file mode 100644
index 0000000..f1bb790
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun12.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun2.png b/funasr/runtime/docs/images/aliyun2.png
new file mode 100644
index 0000000..6d44e37
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun2.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun3.png b/funasr/runtime/docs/images/aliyun3.png
new file mode 100644
index 0000000..6787bef
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun3.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun4.png b/funasr/runtime/docs/images/aliyun4.png
new file mode 100644
index 0000000..d199500
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun4.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun5.png b/funasr/runtime/docs/images/aliyun5.png
new file mode 100644
index 0000000..42914f9
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun5.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun6.png b/funasr/runtime/docs/images/aliyun6.png
new file mode 100644
index 0000000..92f1def
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun6.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun7.png b/funasr/runtime/docs/images/aliyun7.png
new file mode 100644
index 0000000..ec90994
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun7.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun8.png b/funasr/runtime/docs/images/aliyun8.png
new file mode 100644
index 0000000..b7719a0
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun8.png
Binary files differ
diff --git a/funasr/runtime/docs/images/aliyun9.png b/funasr/runtime/docs/images/aliyun9.png
new file mode 100644
index 0000000..f62dba8
--- /dev/null
+++ b/funasr/runtime/docs/images/aliyun9.png
Binary files differ
diff --git a/funasr/runtime/html5/readme.md b/funasr/runtime/html5/readme.md
index 930aa88..0c1eba0 100644
--- a/funasr/runtime/html5/readme.md
+++ b/funasr/runtime/html5/readme.md
@@ -41,7 +41,7 @@
`Tips:` asr service and html5 service should be deployed on the same device.
```shell
cd ../python/websocket
-python wss_srv_asr.py --port 1095
+python wss_srv_asr.py --port 10095
```
@@ -51,8 +51,18 @@
# https://30.220.136.139:1337/static/index.html
```
-### modify asr address in html according to your environment
-asr address in index.html must be wss
+### open browser to open html5 file directly without h5Server
+you can run html5 client by just clicking the index.html file directly in your computer.
+1) lauch asr service without ssl, it must be in ws mode as ssl protocol will prohibit such access.
+2) copy whole directory /funasr/runtime/html5/static to your computer
+3) open /funasr/runtime/html5/static/index.html by browser
+4) enter asr service ws address and connect
+
+
+```shell
+
+```
+
## Acknowledge
diff --git a/funasr/runtime/html5/readme_cn.md b/funasr/runtime/html5/readme_cn.md
index 73bf1b0..b859387 100644
--- a/funasr/runtime/html5/readme_cn.md
+++ b/funasr/runtime/html5/readme_cn.md
@@ -49,7 +49,7 @@
#### wss鏂瑰紡
```shell
cd ../python/websocket
-python wss_srv_asr.py --port 1095
+python wss_srv_asr.py --port 10095
```
### 娴忚鍣ㄦ墦寮�鍦板潃
diff --git a/funasr/runtime/html5/static/index.html b/funasr/runtime/html5/static/index.html
index b99a140..2c76d82 100644
--- a/funasr/runtime/html5/static/index.html
+++ b/funasr/runtime/html5/static/index.html
@@ -14,21 +14,46 @@
<h1>FunASR Demo</h1>
+ <h3>杩欓噷鏄疐unASR寮�婧愰」鐩綋楠宒emo锛岄泦鎴愪簡VAD銆丄SR涓庢爣鐐圭瓑宸ヤ笟绾у埆鐨勬ā鍨嬶紝鏀寔闀块煶棰戠绾挎枃浠惰浆鍐欙紝瀹炴椂璇煶璇嗗埆绛夛紝寮�婧愰」鐩湴鍧�锛歨ttps://github.com/alibaba-damo-academy/FunASR</h3>
+
<div class="div_class_topArea">
<div class="div_class_recordControl">
asr鏈嶅姟鍣ㄥ湴鍧�(蹇呭~):
<br>
- <input id="wssip" type="text" style=" width: 100%;height:100%" value="wss://127.0.0.1:1095/"/>
+ <input id="wssip" type="text" onchange="addresschange()" style=" width: 100%;height:100%" value="wss://127.0.0.1:10095/"/>
<br>
+ <a id="wsslink" href="#" onclick="window.open('https://127.0.0.1:10095/', '_blank')"><div id="info_wslink">鐐规澶勬墜宸ユ巿鏉僿ss://127.0.0.1:10095/</div></a>
<br>
- <div style="border:2px solid #ccc;">
+ <br>
+ <div style="border:2px solid #ccc;">
+ 閫夋嫨褰曢煶妯″紡:<br/>
+
+ <label><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="mic" checked="true"/>楹﹀厠椋� </label>
+ <label><input name="recoder_mode" onclick="on_recoder_mode_change()" type="radio" value="file" />鏂囦欢 </label>
+
+ </div>
+
+ <br>
+ <div id="mic_mode_div" style="border:2px solid #ccc;display:block;">
閫夋嫨asr妯″瀷妯″紡:<br/>
+
<label><input name="asr_mode" type="radio" value="2pass" checked="true"/>2pass </label>
<label><input name="asr_mode" type="radio" value="online" />online </label>
- <label><input name="asr_mode" type="radio" value="offline" />offline </label>
+ <label><input name="asr_mode" type="radio" value="offline" />offline </label>
+
+ </div>
+
+ <div id="rec_mode_div" style="border:2px solid #ccc;display:none;">
+
+
+ <input type="file" id="upfile">
+
</div>
<br>
+
+
+
璇煶璇嗗埆缁撴灉鏄剧ず锛�
<br>
@@ -36,6 +61,7 @@
<br>
<div id="info_div">璇风偣鍑诲紑濮�</div>
<div class="div_class_buttons">
+ <button id="btnConnect">杩炴帴</button>
<button id="btnStart">寮�濮�</button>
<button id="btnStop">鍋滄</button>
diff --git a/funasr/runtime/html5/static/main.js b/funasr/runtime/html5/static/main.js
index 9317778..4a50801 100644
--- a/funasr/runtime/html5/static/main.js
+++ b/funasr/runtime/html5/static/main.js
@@ -23,22 +23,150 @@
var sampleBuf=new Int16Array();
// 瀹氫箟鎸夐挳鍝嶅簲浜嬩欢
var btnStart = document.getElementById('btnStart');
-btnStart.onclick = start;
+btnStart.onclick = record;
var btnStop = document.getElementById('btnStop');
btnStop.onclick = stop;
btnStop.disabled = true;
+btnStart.disabled = true;
+btnConnect= document.getElementById('btnConnect');
+btnConnect.onclick = start;
+
+var awsslink= document.getElementById('wsslink');
-var rec_text="";
-var offline_text="";
+var rec_text=""; // for online rec asr result
+var offline_text=""; // for offline rec asr result
var info_div = document.getElementById('info_div');
-//var now_ipaddress=window.location.href;
-//now_ipaddress=now_ipaddress.replace("https://","wss://");
-//now_ipaddress=now_ipaddress.replace("static/index.html","");
-//document.getElementById('wssip').value=now_ipaddress;
+var upfile = document.getElementById('upfile');
+
+
+var isfilemode=false; // if it is in file mode
+var file_data_array; // array to save file data
+
+var totalsend=0;
+
+
+var now_ipaddress=window.location.href;
+now_ipaddress=now_ipaddress.replace("https://","wss://");
+now_ipaddress=now_ipaddress.replace("static/index.html","");
+var localport=window.location.port;
+now_ipaddress=now_ipaddress.replace(localport,"10095");
+document.getElementById('wssip').value=now_ipaddress;
+addresschange();
+function addresschange()
+{
+
+ var Uri = document.getElementById('wssip').value;
+ document.getElementById('info_wslink').innerHTML="鐐规澶勬墜宸ユ巿鏉冿紙IOS鎵嬫満锛�";
+ Uri=Uri.replace(/wss/g,"https");
+ console.log("addresschange uri=",Uri);
+
+ awsslink.onclick=function(){
+ window.open(Uri, '_blank');
+ }
+
+}
+upfile.onclick=function()
+{
+ btnStart.disabled = true;
+ btnStop.disabled = true;
+ btnConnect.disabled=false;
+
+}
+upfile.onchange = function () {
+銆�銆�銆�銆�銆�銆�var len = this.files.length;
+ for(let i = 0; i < len; i++) {
+ let fileAudio = new FileReader();
+ fileAudio.readAsArrayBuffer(this.files[i]);
+ fileAudio.onload = function() {
+ var audioblob= fileAudio.result;
+ file_data_array=audioblob;
+ console.log(audioblob);
+
+ info_div.innerHTML='璇风偣鍑昏繛鎺ヨ繘琛岃瘑鍒�';
+
+ }
+銆�銆�銆�銆�銆�銆�銆�銆�銆�銆�fileAudio.onerror = function(e) {
+銆�銆�銆�銆�銆�銆�銆�銆�銆�銆�銆�銆�console.log('error' + e);
+銆�銆�銆�銆�銆�銆�銆�銆�銆�銆�}
+ }
+ }
+
+function play_file()
+{
+ var audioblob=new Blob( [ new Uint8Array(file_data_array)] , {type :"audio/wav"});
+ var audio_record = document.getElementById('audio_record');
+ audio_record.src = (window.URL||webkitURL).createObjectURL(audioblob);
+ audio_record.controls=true;
+ //audio_record.play(); //not auto play
+}
+function start_file_send()
+{
+ sampleBuf=new Int16Array( file_data_array );
+
+ var chunk_size=960; // for asr chunk_size [5, 10, 5]
+
+
+
+
+
+ while(sampleBuf.length>=chunk_size){
+
+ sendBuf=sampleBuf.slice(0,chunk_size);
+ totalsend=totalsend+sampleBuf.length;
+ sampleBuf=sampleBuf.slice(chunk_size,sampleBuf.length);
+ wsconnecter.wsSend(sendBuf);
+
+
+ }
+
+ stop();
+
+
+
+}
+
+
+function on_recoder_mode_change()
+{
+ var item = null;
+ var obj = document.getElementsByName("recoder_mode");
+ for (var i = 0; i < obj.length; i++) { //閬嶅巻Radio
+ if (obj[i].checked) {
+ item = obj[i].value;
+ break;
+ }
+
+
+ }
+ if(item=="mic")
+ {
+ document.getElementById("mic_mode_div").style.display = 'block';
+ document.getElementById("rec_mode_div").style.display = 'none';
+
+
+ btnStart.disabled = true;
+ btnStop.disabled = true;
+ btnConnect.disabled=false;
+ isfilemode=false;
+ }
+ else
+ {
+ document.getElementById("mic_mode_div").style.display = 'none';
+ document.getElementById("rec_mode_div").style.display = 'block';
+
+ btnStart.disabled = true;
+ btnStop.disabled = true;
+ btnConnect.disabled=true;
+ isfilemode=true;
+ info_div.innerHTML='璇风偣鍑婚�夋嫨鏂囦欢';
+
+
+ }
+}
function getAsrMode(){
var item = null;
@@ -51,7 +179,12 @@
}
+ if(isfilemode)
+ {
+ item= "offline";
+ }
console.log("asr mode"+item);
+
return item;
}
@@ -64,41 +197,80 @@
var asrmodel=JSON.parse(jsonMsg.data)['mode'];
if(asrmodel=="2pass-offline")
{
- offline_text=offline_text+rectxt.replace(/ +/g,"");
+ offline_text=offline_text+rectxt; //.replace(/ +/g,"");
rec_text=offline_text;
}
else
{
- rec_text=rec_text+rectxt.replace(/ +/g,"");
+ rec_text=rec_text+rectxt; //.replace(/ +/g,"");
}
var varArea=document.getElementById('varArea');
varArea.value=rec_text;
+ console.log( "offline_text: " + asrmodel+","+offline_text);
+ console.log( "rec_text: " + rec_text);
+ if (isfilemode==true){
+ console.log("call stop ws!");
+ play_file();
+ wsconnecter.wsStop();
+
+ info_div.innerHTML="璇风偣鍑昏繛鎺�";
+
+ btnStart.disabled = true;
+ btnStop.disabled = true;
+ btnConnect.disabled=false;
+ }
+
}
// 杩炴帴鐘舵�佸搷搴�
function getConnState( connState ) {
- if ( connState === 0 ) {
+ if ( connState === 0 ) { //on open
- rec.open( function(){
- rec.start();
- console.log("寮�濮嬪綍闊�");
- });
+ info_div.innerHTML='杩炴帴鎴愬姛!璇风偣鍑诲紑濮�';
+ if (isfilemode==true){
+ info_div.innerHTML='璇疯�愬績绛夊緟,澶ф枃浠剁瓑寰呮椂闂存洿闀�';
+ start_file_send();
+ }
+ else
+ {
+ btnStart.disabled = false;
+ btnStop.disabled = true;
+ btnConnect.disabled=true;
+ }
} else if ( connState === 1 ) {
//stop();
} else if ( connState === 2 ) {
stop();
console.log( 'connecttion error' );
- alert("杩炴帴鍦板潃"+document.getElementById('wssip').value+"澶辫触,璇锋鏌sr鍦板潃鍜岀鍙o紝骞剁‘淇漢5鏈嶅姟鍜宎sr鏈嶅姟鍦ㄥ悓涓�涓煙鍐呫�傛垨鎹釜娴忚鍣ㄨ瘯璇曘��");
+ alert("杩炴帴鍦板潃"+document.getElementById('wssip').value+"澶辫触,璇锋鏌sr鍦板潃鍜岀鍙c�傛垨璇曡瘯鐣岄潰涓婃墜鍔ㄦ巿鏉冿紝鍐嶈繛鎺ャ��");
btnStart.disabled = true;
- info_div.innerHTML='璇风偣鍑诲紑濮�';
+ btnStop.disabled = true;
+ btnConnect.disabled=false;
+
+
+ info_div.innerHTML='璇风偣鍑昏繛鎺�';
}
}
+function record()
+{
+
+ rec.open( function(){
+ rec.start();
+ console.log("寮�濮�");
+ btnStart.disabled = true;
+ btnStop.disabled = false;
+ btnConnect.disabled=true;
+ });
+
+}
+
+
// 璇嗗埆鍚姩銆佸仠姝€�佹竻绌烘搷浣�
function start() {
@@ -106,15 +278,28 @@
// 娓呴櫎鏄剧ず
clear();
//鎺т欢鐘舵�佹洿鏂�
-
-
+ console.log("isfilemode"+isfilemode);
+
//鍚姩杩炴帴
var ret=wsconnecter.wsStart();
+ // 1 is ok, 0 is error
if(ret==1){
+ info_div.innerHTML="姝e湪杩炴帴asr鏈嶅姟鍣紝璇风瓑寰�...";
isRec = true;
btnStart.disabled = true;
- btnStop.disabled = false;
- info_div.innerHTML="姝e湪杩炴帴asr鏈嶅姟鍣紝璇风瓑寰�...";
+ btnStop.disabled = true;
+ btnConnect.disabled=true;
+
+ return 1;
+ }
+ else
+ {
+ info_div.innerHTML="璇风偣鍑诲紑濮�";
+ btnStart.disabled = true;
+ btnStop.disabled = true;
+ btnConnect.disabled=false;
+
+ return 0;
}
}
@@ -130,24 +315,35 @@
};
console.log(request);
if(sampleBuf.length>0){
- wsconnecter.wsSend(sampleBuf,false);
+ wsconnecter.wsSend(sampleBuf);
console.log("sampleBuf.length"+sampleBuf.length);
sampleBuf=new Int16Array();
}
- wsconnecter.wsSend( JSON.stringify(request) ,false);
+ wsconnecter.wsSend( JSON.stringify(request) );
+
+
-
-
-
+
// 鎺т欢鐘舵�佹洿鏂�
+
isRec = false;
- info_div.innerHTML="璇风瓑鍊�...";
- btnStop.disabled = true;
- setTimeout(function(){
+ info_div.innerHTML="鍙戦�佸畬鏁版嵁,璇风瓑鍊�,姝e湪璇嗗埆...";
+
+ if(isfilemode==false){
+ btnStop.disabled = true;
+ btnStart.disabled = true;
+ btnConnect.disabled=true;
+ //wait 3s for asr result
+ setTimeout(function(){
console.log("call stop ws!");
- wsconnecter.wsStop();btnStart.disabled = false;info_div.innerHTML="璇风偣鍑诲紑濮�";}, 3000 );
+ wsconnecter.wsStop();
+ btnConnect.disabled=false;
+ info_div.innerHTML="璇风偣鍑昏繛鎺�";}, 3000 );
+
+
+
rec.stop(function(blob,duration){
console.log(blob);
@@ -157,7 +353,7 @@
var audio_record = document.getElementById('audio_record');
audio_record.src = (window.URL||webkitURL).createObjectURL(theblob);
audio_record.controls=true;
- audio_record.play();
+ //audio_record.play();
} ,function(msg){
@@ -170,8 +366,9 @@
},function(errMsg){
console.log("errMsg: " + errMsg);
});
+ }
// 鍋滄杩炴帴
-
+
}
@@ -200,7 +397,7 @@
while(sampleBuf.length>=chunk_size){
sendBuf=sampleBuf.slice(0,chunk_size);
sampleBuf=sampleBuf.slice(chunk_size,sampleBuf.length);
- wsconnecter.wsSend(sendBuf,false);
+ wsconnecter.wsSend(sendBuf);
diff --git a/funasr/runtime/html5/static/wsconnecter.js b/funasr/runtime/html5/static/wsconnecter.js
index 676a94a..2873022 100644
--- a/funasr/runtime/html5/static/wsconnecter.js
+++ b/funasr/runtime/html5/static/wsconnecter.js
@@ -15,8 +15,7 @@
this.wsStart = function () {
var Uri = document.getElementById('wssip').value; //"wss://111.205.137.58:5821/wss/" //璁剧疆wss asr online鎺ュ彛鍦板潃 濡� wss://X.X.X.X:port/wss/
-
- if(Uri.match(/wss:\S*/))
+ if(Uri.match(/wss:\S*|ws:\S*/))
{
console.log("Uri"+Uri);
}
@@ -25,12 +24,13 @@
alert("璇锋鏌ss鍦板潃姝g‘鎬�");
return 0;
}
+
if ( 'WebSocket' in window ) {
speechSokt = new WebSocket( Uri ); // 瀹氫箟socket杩炴帴瀵硅薄
speechSokt.onopen = function(e){onOpen(e);}; // 瀹氫箟鍝嶅簲鍑芥暟
speechSokt.onclose = function(e){
console.log("onclose ws!");
- speechSokt.close();
+ //speechSokt.close();
onClose(e);
};
speechSokt.onmessage = function(e){onMessage(e);};
@@ -51,16 +51,13 @@
}
};
- this.wsSend = function ( oneData,stop ) {
+ this.wsSend = function ( oneData ) {
if(speechSokt == undefined) return;
if ( speechSokt.readyState === 1 ) { // 0:CONNECTING, 1:OPEN, 2:CLOSING, 3:CLOSED
speechSokt.send( oneData );
- if(stop){
- setTimeout(speechSokt.close(), 3000 );
- }
}
};
@@ -80,6 +77,7 @@
speechSokt.send( JSON.stringify(request) );
console.log("杩炴帴鎴愬姛");
stateHandle(0);
+
}
function onClose( e ) {
@@ -92,9 +90,11 @@
}
function onError( e ) {
+
info_div.innerHTML="杩炴帴"+e;
console.log(e);
stateHandle(2);
+
}
diff --git a/funasr/runtime/java/FunasrWsClient.java b/funasr/runtime/java/FunasrWsClient.java
new file mode 100644
index 0000000..ec55c94
--- /dev/null
+++ b/funasr/runtime/java/FunasrWsClient.java
@@ -0,0 +1,344 @@
+//
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License (https://opensource.org/licenses/MIT)
+//
+/*
+ * // 2022-2023 by zhaomingwork@qq.com
+ */
+// java FunasrWsClient
+// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+// [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+package websocket;
+
+import java.io.*;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.*;
+import java.util.Map;
+import net.sourceforge.argparse4j.ArgumentParsers;
+import net.sourceforge.argparse4j.inf.ArgumentParser;
+import net.sourceforge.argparse4j.inf.ArgumentParserException;
+import net.sourceforge.argparse4j.inf.Namespace;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.drafts.Draft;
+import org.java_websocket.handshake.ServerHandshake;
+import org.json.simple.JSONArray;
+import org.json.simple.JSONObject;
+import org.json.simple.parser.JSONParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** This example demonstrates how to connect to websocket server. */
+public class FunasrWsClient extends WebSocketClient {
+
+ public class RecWavThread extends Thread {
+ private FunasrWsClient funasrClient;
+
+ public RecWavThread(FunasrWsClient funasrClient) {
+ this.funasrClient = funasrClient;
+ }
+
+ public void run() {
+ this.funasrClient.recWav();
+ }
+ }
+
+ private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
+
+ public FunasrWsClient(URI serverUri, Draft draft) {
+ super(serverUri, draft);
+ }
+
+ public FunasrWsClient(URI serverURI) {
+ super(serverURI);
+ }
+
+ public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
+ super(serverUri, httpHeaders);
+ }
+
+ public void getSslContext(String keyfile, String certfile) {
+ // TODO
+ return;
+ }
+
+ // send json at first time
+ public void sendJson(
+ String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
+ try {
+
+ JSONObject obj = new JSONObject();
+ obj.put("mode", mode);
+ JSONArray array = new JSONArray();
+ String[] chunkList = strChunkSize.split(",");
+ for (int i = 0; i < chunkList.length; i++) {
+ array.add(Integer.valueOf(chunkList[i].trim()));
+ }
+
+ obj.put("chunk_size", array);
+ obj.put("chunk_interval", new Integer(chunkInterval));
+ obj.put("wav_name", wavName);
+ if (isSpeaking) {
+ obj.put("is_speaking", new Boolean(true));
+ } else {
+ obj.put("is_speaking", new Boolean(false));
+ }
+ logger.info("sendJson: " + obj);
+ // return;
+
+ send(obj.toString());
+
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // send json at end of wav
+ public void sendEof() {
+ try {
+ JSONObject obj = new JSONObject();
+
+ obj.put("is_speaking", new Boolean(false));
+
+ logger.info("sendEof: " + obj);
+ // return;
+
+ send(obj.toString());
+ iseof = true;
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // function for rec wav file
+ public void recWav() {
+ sendJson(mode, strChunkSize, chunkInterval, wavName, true);
+ File file = new File(FunasrWsClient.wavPath);
+
+ int chunkSize = sendChunkSize;
+ byte[] bytes = new byte[chunkSize];
+
+ int readSize = 0;
+ try (FileInputStream fis = new FileInputStream(file)) {
+ if (FunasrWsClient.wavPath.endsWith(".wav")) {
+ fis.read(bytes, 0, 44); //skip first 44 wav header
+ }
+ readSize = fis.read(bytes, 0, chunkSize);
+ while (readSize > 0) {
+ // send when it is chunk size
+ if (readSize == chunkSize) {
+ send(bytes); // send buf to server
+
+ } else {
+ // send when at last or not is chunk size
+ byte[] tmpBytes = new byte[readSize];
+ for (int i = 0; i < readSize; i++) {
+ tmpBytes[i] = bytes[i];
+ }
+ send(tmpBytes);
+ }
+ // if not in offline mode, we simulate online stream by sleep
+ if (!mode.equals("offline")) {
+ Thread.sleep(Integer.valueOf(chunkSize / 32));
+ }
+
+ readSize = fis.read(bytes, 0, chunkSize);
+ }
+
+ if (!mode.equals("offline")) {
+ // if not offline, we send eof and wait for 3 seconds to close
+ Thread.sleep(2000);
+ sendEof();
+ Thread.sleep(3000);
+ close();
+ } else {
+ // if offline, just send eof
+ sendEof();
+ }
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onOpen(ServerHandshake handshakedata) {
+
+ RecWavThread thread = new RecWavThread(this);
+ thread.start();
+ }
+
+ @Override
+ public void onMessage(String message) {
+ JSONObject jsonObject = new JSONObject();
+ JSONParser jsonParser = new JSONParser();
+ logger.info("received: " + message);
+ try {
+ jsonObject = (JSONObject) jsonParser.parse(message);
+ logger.info("text: " + jsonObject.get("text"));
+ } catch (org.json.simple.parser.ParseException e) {
+ e.printStackTrace();
+ }
+ if (iseof && mode.equals("offline")) {
+ close();
+ }
+ }
+
+ @Override
+ public void onClose(int code, String reason, boolean remote) {
+
+ logger.info(
+ "Connection closed by "
+ + (remote ? "remote peer" : "us")
+ + " Code: "
+ + code
+ + " Reason: "
+ + reason);
+ }
+
+ @Override
+ public void onError(Exception ex) {
+ logger.info("ex: " + ex);
+ ex.printStackTrace();
+ // if the error is fatal then onClose will be called additionally
+ }
+
+ private boolean iseof = false;
+ public static String wavPath;
+ static String mode = "online";
+ static String strChunkSize = "5,10,5";
+ static int chunkInterval = 10;
+ static int sendChunkSize = 1920;
+
+ String wavName = "javatest";
+
+ public static void main(String[] args) throws URISyntaxException {
+ ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
+ parser
+ .addArgument("--port")
+ .help("Port on which to listen.")
+ .setDefault("8889")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--host")
+ .help("the IP address of server.")
+ .setDefault("127.0.0.1")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--audio_in")
+ .help("wav path for decoding.")
+ .setDefault("asr_example.wav")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--num_threads")
+ .help("num of threads for test.")
+ .setDefault(1)
+ .type(Integer.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_size")
+ .help("chunk size for asr.")
+ .setDefault("5, 10, 5")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_interval")
+ .help("chunk for asr.")
+ .setDefault(10)
+ .type(Integer.class)
+ .required(false);
+
+ parser
+ .addArgument("--mode")
+ .help("mode for asr.")
+ .setDefault("offline")
+ .type(String.class)
+ .required(false);
+ String srvIp = "";
+ String srvPort = "";
+ String wavPath = "";
+ int numThreads = 1;
+ String chunk_size = "";
+ int chunk_interval = 10;
+ String strmode = "offline";
+
+ try {
+ Namespace ns = parser.parseArgs(args);
+ srvIp = ns.get("host");
+ srvPort = ns.get("port");
+ wavPath = ns.get("audio_in");
+ numThreads = ns.get("num_threads");
+ chunk_size = ns.get("chunk_size");
+ chunk_interval = ns.get("chunk_interval");
+ strmode = ns.get("mode");
+ System.out.println(srvPort);
+
+ } catch (ArgumentParserException ex) {
+ ex.getParser().handleError(ex);
+ return;
+ }
+
+ FunasrWsClient.strChunkSize = chunk_size;
+ FunasrWsClient.chunkInterval = chunk_interval;
+ FunasrWsClient.wavPath = wavPath;
+ FunasrWsClient.mode = strmode;
+ System.out.println(
+ "serIp="
+ + srvIp
+ + ",srvPort="
+ + srvPort
+ + ",wavPath="
+ + wavPath
+ + ",strChunkSize"
+ + strChunkSize);
+
+ class ClientThread implements Runnable {
+
+ String srvIp;
+ String srvPort;
+
+ ClientThread(String srvIp, String srvPort, String wavPath) {
+ this.srvIp = srvIp;
+ this.srvPort = srvPort;
+ }
+
+ public void run() {
+ try {
+
+ int RATE = 16000;
+ String[] chunkList = strChunkSize.split(",");
+ int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
+ int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
+ int stride =
+ Integer.valueOf(
+ 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
+ System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
+ System.out.println("CHUNK:" + CHUNK);
+ System.out.println("stride:" + String.valueOf(stride));
+ FunasrWsClient.sendChunkSize = CHUNK * 2;
+
+ String wsAddress = "ws://" + srvIp + ":" + srvPort;
+
+ FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
+
+ c.connect();
+
+ System.out.println("wsAddress:" + wsAddress);
+ } catch (Exception e) {
+ e.printStackTrace();
+ System.out.println("e:" + e);
+ }
+ }
+ }
+ for (int i = 0; i < numThreads; i++) {
+ System.out.println("Thread1 is running...");
+ Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
+ t.start();
+ }
+ }
+}
diff --git a/funasr/runtime/java/Makefile b/funasr/runtime/java/Makefile
new file mode 100644
index 0000000..9a70ca5
--- /dev/null
+++ b/funasr/runtime/java/Makefile
@@ -0,0 +1,76 @@
+
+ENTRY_POINT = ./
+
+
+
+
+WEBSOCKET_DIR:= ./
+WEBSOCKET_FILES = \
+ $(WEBSOCKET_DIR)/FunasrWsClient.java \
+
+
+
+LIB_BUILD_DIR = ./lib
+
+
+
+
+JAVAC = javac
+
+BUILD_DIR = build
+
+
+RUNJFLAGS = -Dfile.encoding=utf-8
+
+
+vpath %.class $(BUILD_DIR)
+vpath %.java src
+
+
+
+
+rebuild: clean all
+
+.PHONY: clean run downjar
+
+downjar:
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
+ #wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
+ wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
+ rm -frv build
+ mkdir build
+clean:
+ rm -frv $(BUILD_DIR)/*
+ rm -frv $(LIB_BUILD_DIR)/*
+ mkdir -p $(BUILD_DIR)
+ mkdir -p ./lib
+
+
+
+
+
+
+runclient:
+ java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+
+
+buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
+
+
+%.class: %.java
+
+ $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
+
+packjar:
+ jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
+
+all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
+
+
+
+
+
diff --git a/funasr/runtime/java/readme.md b/funasr/runtime/java/readme.md
new file mode 100644
index 0000000..406a21a
--- /dev/null
+++ b/funasr/runtime/java/readme.md
@@ -0,0 +1,66 @@
+# Client for java websocket example
+
+
+
+## Building for Linux/Unix
+
+### install java environment
+```shell
+# in ubuntu
+apt-get install openjdk-11-jdk
+```
+
+
+
+### Build and run by make
+
+
+```shell
+cd funasr/runtime/java
+# download java lib
+make downjar
+# compile
+make buildwebsocket
+# run client
+make runclient
+
+```
+
+## Run java websocket client by shell
+
+```shell
+# full command refer to Makefile runclient
+usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+ [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+
+Where:
+ --host <string>
+ (required) server-ip
+
+ --port <int>
+ (required) port
+
+ --audio_in <string>
+ (required) the wav or pcm file path
+
+ --num_threads <int>
+ thread number for test
+
+ --mode
+ asr mode, support "offline" "online" "2pass"
+
+
+
+example:
+FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+result json, example like:
+{"mode":"offline","text":"娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�","wav_name":"javatest"}
+```
+
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.
+
+
diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
index 962da0b..03c3a64 100644
--- a/funasr/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -12,5 +12,8 @@
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
+add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp")
+target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
+
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
index e18c27e..92c0525 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@@ -84,11 +84,13 @@
long taking_micros = 0;
for(auto& txt_str : txt_list){
gettimeofday(&start, NULL);
- string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
- LOG(INFO)<<"Results: "<<result;
+ string msg = FunASRGetResult(result, 0);
+ LOG(INFO)<<"Results: "<<msg;
+ CTTransformerFreeResult(result);
}
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index d2692ce..ee05d75 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -59,13 +59,13 @@
if(result){
string msg = FunASRGetResult(result, 0);
- LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
+ LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
FunASRFreeResult(result);
}else{
- LOG(ERROR) << ("No return data!\n");
+ LOG(ERROR) << wav_ids[i] << (": No return data!\n");
}
}
{
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
new file mode 100644
index 0000000..c592616
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
@@ -0,0 +1,130 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+using namespace std;
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+ }
+}
+
+void splitString(vector<string>& strings, const string& org_string, const string& seq) {
+ string::size_type p1 = 0;
+ string::size_type p2 = org_string.find(seq);
+
+ while (p2 != string::npos) {
+ if (p2 == p1) {
+ ++p1;
+ p2 = org_string.find(seq, p1);
+ continue;
+ }
+ strings.push_back(org_string.substr(p1, p2 - p1));
+ p1 = p2 + seq.size();
+ p2 = org_string.find(seq, p1);
+ }
+
+ if (p1 != org_string.size()) {
+ strings.push_back(org_string.substr(p1));
+ }
+}
+
+int main(int argc, char *argv[])
+{
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
+ TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
+
+ cmd.add(model_dir);
+ cmd.add(quantize);
+ cmd.add(txt_path);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(txt_path, TXT_PATH, model_path);
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ int thread_num = 1;
+ FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num, PUNC_ONLINE);
+
+ if (!punc_hanlde)
+ {
+ LOG(ERROR) << "FunASR init failed";
+ exit(-1);
+ }
+
+ gettimeofday(&end, NULL);
+ long seconds = (end.tv_sec - start.tv_sec);
+ long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+ // read txt_path
+ vector<string> txt_list;
+
+ if(model_path.find(TXT_PATH)!=model_path.end()){
+ ifstream in(model_path.at(TXT_PATH));
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ txt_list.emplace_back(line);
+ }
+ in.close();
+ }
+
+ long taking_micros = 0;
+ for(auto& txt_str : txt_list){
+ vector<string> vad_strs;
+ splitString(vad_strs, txt_str, "|");
+ string str_out;
+ FUNASR_RESULT result = nullptr;
+ gettimeofday(&start, NULL);
+ for(auto& vad_str:vad_strs){
+ result=CTTransformerInfer(punc_hanlde, vad_str.c_str(), RASR_NONE, NULL, PUNC_ONLINE, result);
+ if(result){
+ string msg = CTTransformerGetResult(result, 0);
+ str_out += msg;
+ LOG(INFO)<<"Online result: "<<msg;
+ }
+ }
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO)<<"Results: "<<str_out;
+ CTTransformerFreeResult(result);
+ }
+
+ LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+ CTTransformerUninit(punc_hanlde);
+ return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 7a6345b..0d3aee0 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -69,6 +69,7 @@
#define CANDIDATE_NUM 6
#define UNKNOW_INDEX 0
+#define NOTPUNC "_"
#define NOTPUNC_INDEX 1
#define COMMA_INDEX 2
#define PERIOD_INDEX 3
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index af430f7..98727bd 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,6 +46,11 @@
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
+typedef enum {
+ PUNC_OFFLINE=0,
+ PUNC_ONLINE=1,
+}PUNC_TYPE;
+
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// ASR
@@ -75,8 +80,10 @@
_FUNASRAPI const float FsmnVadGetRetSnippetTime(FUNASR_RESULT result);
// PUNC
-_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num);
-_FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
+_FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type=PUNC_OFFLINE, FUNASR_RESULT pre_result=nullptr);
+_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result);
_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h
index da7ff60..4266eea 100644
--- a/funasr/runtime/onnxruntime/include/punc-model.h
+++ b/funasr/runtime/onnxruntime/include/punc-model.h
@@ -5,16 +5,17 @@
#include <string>
#include <map>
#include <vector>
+#include "funasrruntime.h"
namespace funasr {
class PuncModel {
public:
virtual ~PuncModel(){};
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
- virtual std::vector<int> Infer(std::vector<int32_t> input_data)=0;
- virtual std::string AddPunc(const char* sz_input)=0;
+ virtual std::string AddPunc(const char* sz_input){return "";};
+ virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){return "";};
};
-PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index d0882c6..b74c1c1 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -14,6 +14,11 @@
float snippet_time;
}FUNASR_VAD_RESULT;
+typedef struct
+{
+ string msg;
+ vector<string> arr_cache;
+}FUNASR_PUNC_RESULT;
#ifdef _WIN32
#include <codecvt>
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
new file mode 100644
index 0000000..191cda8
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -0,0 +1,283 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+
+namespace funasr {
+CTTransformerOnline::CTTransformerOnline()
+:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
+{
+}
+
+void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
+ session_options.SetIntraOpNumThreads(thread_num);
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ session_options.DisableCpuMemArena();
+
+ try{
+ m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << punc_model;
+ }
+ catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load punc onnx model: " << e.what();
+ exit(0);
+ }
+ // read inputnames outputnames
+ string strName;
+ GetInputName(m_session.get(), strName);
+ m_strInputNames.push_back(strName.c_str());
+ GetInputName(m_session.get(), strName, 1);
+ m_strInputNames.push_back(strName);
+ GetInputName(m_session.get(), strName, 2);
+ m_strInputNames.push_back(strName);
+ GetInputName(m_session.get(), strName, 3);
+ m_strInputNames.push_back(strName);
+
+ GetOutputName(m_session.get(), strName);
+ m_strOutputNames.push_back(strName);
+
+ for (auto& item : m_strInputNames)
+ m_szInputNames.push_back(item.c_str());
+ for (auto& item : m_strOutputNames)
+ m_szOutputNames.push_back(item.c_str());
+
+ m_tokenizer.OpenYaml(punc_config.c_str());
+}
+
+CTTransformerOnline::~CTTransformerOnline()
+{
+}
+
+string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache)
+{
+ string strResult;
+ vector<string> strOut;
+ vector<int> InputData;
+ string strText; //full_text
+ strText = accumulate(arr_cache.begin(), arr_cache.end(), strText);
+ strText += sz_input; // full_text = precache + text
+ m_tokenizer.Tokenize(strText.c_str(), strOut, InputData);
+
+ int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
+ int nCurBatch = -1;
+ int nSentEnd = -1, nLastCommaIndex = -1;
+ vector<int32_t> RemainIDs; //
+ vector<string> RemainStr; //
+ vector<int> new_mini_sentence_punc; // sentence_punc_list = []
+ vector<string> sentenceOut; // sentenceOut
+ vector<string> sentence_punc_list,sentence_words_list,sentence_punc_list_out; // sentence_words_list = []
+
+ int nSkipNum = 0;
+ int nDiff = 0;
+ for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
+ {
+ nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
+ vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
+ vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
+ InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
+ InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
+
+ auto Punction = Infer(InputIDs, arr_cache.size());
+ nCurBatch = i / TOKEN_LEN;
+ if (nCurBatch < nTotalBatch - 1) // not the last minisetence
+ {
+ nSentEnd = -1;
+ nLastCommaIndex = -1;
+ for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
+ {
+ if (m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(PERIOD_INDEX) || m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(QUESTION_INDEX))
+ {
+ nSentEnd = nIndex;
+ break;
+ }
+ if (nLastCommaIndex < 0 && m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(COMMA_INDEX))
+ {
+ nLastCommaIndex = nIndex;
+ }
+ }
+ if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
+ {
+ nSentEnd = nLastCommaIndex;
+ Punction[nSentEnd] = PERIOD_INDEX;
+ }
+ RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
+ RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
+ InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence
+ Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
+ }
+
+ for (auto& item : Punction)
+ {
+ sentence_punc_list.push_back(m_tokenizer.Id2Punc(item));
+ }
+
+ sentence_words_list.insert(sentence_words_list.end(), InputStr.begin(), InputStr.end());
+
+ new_mini_sentence_punc.insert(new_mini_sentence_punc.end(), Punction.begin(), Punction.end());
+ }
+ vector<string> WordWithPunc;
+ for (int i = 0; i < sentence_words_list.size(); i++) // for i in range(0, len(sentence_words_list)):
+ {
+ if (i > 0 && !(sentence_words_list[i][0] & 0x80) && (i + 1) < sentence_words_list.size() && !(sentence_words_list[i + 1][0] & 0x80))
+ {
+ sentence_words_list[i] = sentence_words_list[i] + " ";
+ }
+ if (nSkipNum < arr_cache.size()) // if skip_num < len(cache):
+ nSkipNum++;
+ else
+ WordWithPunc.push_back(sentence_words_list[i]);
+
+ if (nSkipNum >= arr_cache.size())
+ {
+ sentence_punc_list_out.push_back(sentence_punc_list[i]);
+ if (sentence_punc_list[i] != NOTPUNC)
+ {
+ WordWithPunc.push_back(sentence_punc_list[i]);
+ }
+ }
+ }
+
+ sentenceOut.insert(sentenceOut.end(), WordWithPunc.begin(), WordWithPunc.end()); //
+ nSentEnd = -1;
+ for (int i = sentence_punc_list.size() - 2; i > 0; i--)
+ {
+ if (new_mini_sentence_punc[i] == PERIOD_INDEX || new_mini_sentence_punc[i] == QUESTION_INDEX)
+ {
+ nSentEnd = i;
+ break;
+ }
+ }
+ arr_cache.assign(sentence_words_list.begin() + nSentEnd + 1, sentence_words_list.end());
+
+ if (sentenceOut.size() > 0 && m_tokenizer.IsPunc(sentenceOut[sentenceOut.size() - 1]))
+ {
+ sentenceOut.assign(sentenceOut.begin(), sentenceOut.end() - 1);
+ sentence_punc_list_out[sentence_punc_list_out.size() - 1] = m_tokenizer.Id2Punc(NOTPUNC_INDEX);
+ }
+ return accumulate(sentenceOut.begin(), sentenceOut.end(), string(""));
+}
+
+vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSize)
+{
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+ vector<int> punction;
+ std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
+ Ort::Value onnx_input = Ort::Value::CreateTensor(
+ m_memoryInfo,
+ input_data.data(),
+ input_data.size() * sizeof(int32_t),
+ input_shape_.data(),
+ input_shape_.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
+
+ std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
+ std::array<int64_t,1> text_lengths_dim{ 1 };
+ Ort::Value onnx_text_lengths = Ort::Value::CreateTensor<int32_t>(
+ m_memoryInfo,
+ text_lengths.data(),
+ text_lengths.size(),
+ text_lengths_dim.data(),
+ text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
+
+ //vad_mask
+ vector<float> arVadMask,arSubMask;
+ int nTextLength = input_data.size();
+
+ VadMask(nTextLength, nCacheSize, arVadMask);
+ Triangle(nTextLength, arSubMask);
+ std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
+ Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
+ m_memoryInfo,
+ arVadMask.data(),
+ arVadMask.size(), // * sizeof(float),
+ VadMask_Dim.data(),
+ VadMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+ //sub_masks
+
+ std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
+ Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
+ m_memoryInfo,
+ arSubMask.data(),
+ arSubMask.size() ,
+ SubMask_Dim.data(),
+ SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+
+ std::vector<Ort::Value> input_onnx;
+ input_onnx.emplace_back(std::move(onnx_input));
+ input_onnx.emplace_back(std::move(onnx_text_lengths));
+ input_onnx.emplace_back(std::move(onnx_vad_mask));
+ input_onnx.emplace_back(std::move(onnx_sub_mask));
+
+ try {
+ auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
+ std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
+
+ int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
+ float * floatData = outputTensor[0].GetTensorMutableData<float>();
+
+ for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
+ {
+ int index = Argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
+ punction.push_back(index);
+ }
+ }
+ catch (std::exception const &e)
+ {
+ LOG(ERROR) << "Error when run punc onnx forword: " << (e.what());
+ exit(0);
+ }
+ return punction;
+}
+
+void CTTransformerOnline::VadMask(int nSize, int vad_pos, vector<float>& Result)
+{
+ Result.resize(0);
+ Result.assign(nSize * nSize, 1);
+ if (vad_pos <= 0 || vad_pos >= nSize)
+ {
+ return;
+ }
+ for (int i = 0; i < vad_pos-1; i++)
+ {
+ for (int j = vad_pos; j < nSize; j++)
+ {
+ Result[i * nSize + j] = 0.0f;
+ }
+ }
+}
+
+void CTTransformerOnline::Triangle(int text_length, vector<float>& Result)
+{
+ Result.resize(0);
+ Result.assign(text_length * text_length,1); // generate a zeros: text_length x text_length
+
+ for (int i = 0; i < text_length; i++) // rows
+ {
+ for (int j = i+1; j<text_length; j++) //cols
+ {
+ Result[i * text_length + j] = 0.0f;
+ }
+
+ }
+ //Transport(Result, text_length, text_length);
+}
+
+void CTTransformerOnline::Transport(vector<float>& In,int nRows, int nCols)
+{
+ vector<float> Out;
+ Out.resize(nRows * nCols);
+ int i = 0;
+ for (int j = 0; j < nCols; j++) {
+ for (; i < nRows * nCols; i++) {
+ Out[i] = In[j + nCols * (i % nRows)];
+ if ((i + 1) % nRows == 0) {
+ i++;
+ break;
+ }
+ }
+ }
+ In = Out;
+}
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer-online.h b/funasr/runtime/onnxruntime/src/ct-transformer-online.h
new file mode 100644
index 0000000..5db183a
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/ct-transformer-online.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#pragma once
+
+namespace funasr {
+class CTTransformerOnline : public PuncModel {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ * https://arxiv.org/pdf/2003.01309.pdf
+*/
+
+private:
+
+ CTokenizer m_tokenizer;
+ vector<string> m_strInputNames, m_strOutputNames;
+ vector<const char*> m_szInputNames;
+ vector<const char*> m_szOutputNames;
+
+ std::shared_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions session_options;
+public:
+
+ CTTransformerOnline();
+ void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
+ ~CTTransformerOnline();
+ vector<int> Infer(vector<int32_t> input_data, int nCacheSize);
+ string AddPunc(const char* sz_input, vector<string> &arr_cache);
+ void Transport(vector<float>& In, int nRows, int nCols);
+ void VadMask(int size, int vad_pos,vector<float>& Result);
+ void Triangle(int text_length, vector<float>& Result);
+};
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 58eec25..2ee4114 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -18,6 +18,7 @@
try{
m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << punc_model;
}
catch (std::exception const &e) {
LOG(ERROR) << "Error when load punc onnx model: " << e.what();
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index f504b39..82fdd70 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -23,9 +23,9 @@
return mm;
}
- _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num)
+ _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type)
{
- funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num);
+ funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num, type);
return mm;
}
@@ -51,6 +51,9 @@
int flag = 0;
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
@@ -90,6 +93,9 @@
int n_total = audio.GetQueueSize();
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, flag);
p_result->msg += msg;
@@ -114,6 +120,9 @@
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
vector<std::vector<int>> vad_segments;
audio.Split(vad_obj, vad_segments, input_finished);
@@ -143,6 +152,9 @@
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
vector<std::vector<int>> vad_segments;
audio.Split(vad_obj, vad_segments, true);
@@ -152,14 +164,28 @@
}
// APIs for PUNC Infer
- _FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
+ _FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type, FUNASR_RESULT pre_result)
{
funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle;
if (!punc_obj)
return nullptr;
+
+ FUNASR_RESULT p_result = nullptr;
+ if (type==PUNC_OFFLINE){
+ p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
+ ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence);
+ }else if(type==PUNC_ONLINE){
+ if (!pre_result)
+ p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
+ else
+ p_result = pre_result;
+ ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence, ((funasr::FUNASR_PUNC_RESULT*)p_result)->arr_cache);
+ }else{
+ LOG(ERROR) << "Wrong PUNC_TYPE";
+ exit(-1);
+ }
- string punc_res = punc_obj->AddPunc(sz_sentence);
- return punc_res;
+ return p_result;
}
// APIs for Offline-stream Infer
@@ -172,6 +198,11 @@
funasr::Audio audio(1);
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
if(offline_stream->UseVad()){
audio.Split(offline_stream);
}
@@ -179,8 +210,7 @@
float* buff;
int len;
int flag = 0;
- funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
+
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
@@ -216,6 +246,11 @@
LOG(ERROR)<<"Wrong wav extension";
exit(-1);
}
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio.GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
if(offline_stream->UseVad()){
audio.Split(offline_stream);
}
@@ -225,8 +260,6 @@
int flag = 0;
int n_step = 0;
int n_total = audio.GetQueueSize();
- funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
- p_result->snippet_time = audio.GetTimeLen();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
p_result->msg+= msg;
@@ -277,6 +310,15 @@
return p_result->msg.c_str();
}
+ _FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
+ {
+ funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
+ if(!p_result)
+ return nullptr;
+
+ return p_result->msg.c_str();
+ }
+
_FUNASRAPI vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
@@ -295,6 +337,14 @@
}
}
+ _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result)
+ {
+ if (result)
+ {
+ delete (funasr::FUNASR_PUNC_RESULT*)result;
+ }
+ }
+
_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result)
{
funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
diff --git a/funasr/runtime/onnxruntime/src/offline-stream.cpp b/funasr/runtime/onnxruntime/src/offline-stream.cpp
index 8170129..d96cf27 100644
--- a/funasr/runtime/onnxruntime/src/offline-stream.cpp
+++ b/funasr/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,11 +1,11 @@
#include "precomp.h"
+#include <unistd.h>
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
{
// VAD model
if(model_path.find(VAD_DIR) != model_path.end()){
- use_vad = true;
string vad_model_path;
string vad_cmvn_path;
string vad_config_path;
@@ -16,8 +16,16 @@
}
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
- vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ if (access(vad_model_path.c_str(), F_OK) != 0 ||
+ access(vad_cmvn_path.c_str(), F_OK) != 0 ||
+ access(vad_config_path.c_str(), F_OK) != 0 )
+ {
+ LOG(INFO) << "VAD model file is not exist, skip load vad model.";
+ }else{
+ vad_handle = make_unique<FsmnVad>();
+ vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ use_vad = true;
+ }
}
// AM model
@@ -39,7 +47,6 @@
// PUNC model
if(model_path.find(PUNC_DIR) != model_path.end()){
- use_punc = true;
string punc_model_path;
string punc_config_path;
@@ -49,8 +56,15 @@
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
- punc_handle = make_unique<CTTransformer>();
- punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ if (access(punc_model_path.c_str(), F_OK) != 0 ||
+ access(punc_config_path.c_str(), F_OK) != 0 )
+ {
+ LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
+ }else{
+ punc_handle = make_unique<CTTransformer>();
+ punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ use_punc = true;
+ }
}
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 1957a12..b605fff 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -33,6 +33,7 @@
try {
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
exit(0);
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 838dddc..26ed2c5 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -36,6 +36,7 @@
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
+#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "fsmn-vad-online.h"
diff --git a/funasr/runtime/onnxruntime/src/punc-model.cpp b/funasr/runtime/onnxruntime/src/punc-model.cpp
index 52ba0df..54b8d6a 100644
--- a/funasr/runtime/onnxruntime/src/punc-model.cpp
+++ b/funasr/runtime/onnxruntime/src/punc-model.cpp
@@ -1,11 +1,17 @@
#include "precomp.h"
namespace funasr {
-PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num)
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type)
{
PuncModel *mm;
- mm = new CTTransformer();
-
+ if (type==PUNC_OFFLINE){
+ mm = new CTTransformer();
+ }else if(type==PUNC_ONLINE){
+ mm = new CTTransformerOnline();
+ }else{
+ LOG(ERROR) << "Wrong PUNC TYPE";
+ exit(-1);
+ }
string punc_model_path;
string punc_config_path;
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index a8f6301..cd3f027 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -142,6 +142,14 @@
return result;
}
+bool CTokenizer::IsPunc(string& Punc)
+{
+ if (m_punc2id.find(Punc) != m_punc2id.end())
+ return true;
+ else
+ return false;
+}
+
vector<string> CTokenizer::SplitChineseString(const string & str_info)
{
vector<string> list;
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h
index 419791b..3b1d1c5 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.h
+++ b/funasr/runtime/onnxruntime/src/tokenizer.h
@@ -30,7 +30,7 @@
vector<string> SplitChineseString(const string& str_info);
void StrSplit(const string& str, const char split, vector<string>& res);
void Tokenize(const char* str_info, vector<string>& str_out, vector<int>& id_out);
-
+ bool IsPunc(string& Punc);
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index 65af8b6..70553df 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -110,17 +110,16 @@
else {
// pre word is chinese
if (!is_pre_english) {
- word[0] = word[0] - 32;
+ // word[0] = word[0] - 32;
words.push_back(word);
pre_english_len = word.size();
-
}
// pre word is english word
else {
// single letter turn to upper case
- if (word.size() == 1) {
- word[0] = word[0] - 32;
- }
+ // if (word.size() == 1) {
+ // word[0] = word[0] - 32;
+ // }
if (pre_english_len > 1) {
words.push_back(" ");
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/frontend.py b/funasr/runtime/python/libtorch/funasr_torch/utils/frontend.py
index 11a8644..fe39955 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/frontend.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/frontend.py
@@ -3,7 +3,6 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
-from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
@@ -28,7 +27,6 @@
dither: float = 1.0,
**kwargs,
) -> None:
- check_argument_types()
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
index 86e78bc..913ddc1 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -9,7 +9,6 @@
import numpy as np
import yaml
-from typeguard import check_argument_types
import warnings
@@ -21,7 +20,6 @@
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
- check_argument_types()
self.token_list = token_list
self.unk_symbol = token_list[-1]
@@ -51,7 +49,6 @@
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
- check_argument_types()
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
diff --git a/funasr/runtime/python/libtorch/setup.py b/funasr/runtime/python/libtorch/setup.py
index fd8b151..4b20c0b 100644
--- a/funasr/runtime/python/libtorch/setup.py
+++ b/funasr/runtime/python/libtorch/setup.py
@@ -25,10 +25,13 @@
long_description=get_readme(),
long_description_content_type='text/markdown',
include_package_data=True,
- install_requires=["librosa", "onnxruntime>=1.7.0",
- "scipy", "numpy>=1.19.3",
- "typeguard", "kaldi-native-fbank",
- "PyYAML>=5.1.2", "torch-quant >= 0.4.0"],
+ install_requires=["librosa",
+ "onnxruntime>=1.7.0",
+ "scipy",
+ "numpy>=1.19.3",
+ "kaldi-native-fbank",
+ "PyYAML>=5.1.2",
+ "torch-quant >= 0.4.0"],
packages=find_packages(include=["torch_paraformer*"]),
keywords=[
'funasr, paraformer, funasr_torch'
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
index 5478236..ded04b6 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
@@ -4,7 +4,6 @@
import copy
import numpy as np
-from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
@@ -29,7 +28,6 @@
dither: float = 1.0,
**kwargs,
) -> None:
- check_argument_types()
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index dcee425..9284943 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -10,7 +10,6 @@
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions, get_available_providers, get_device)
-from typeguard import check_argument_types
import warnings
@@ -22,7 +21,6 @@
class TokenIDConverter():
def __init__(self, token_list: Union[List, str],
):
- check_argument_types()
self.token_list = token_list
self.unk_symbol = token_list[-1]
@@ -52,7 +50,6 @@
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
- check_argument_types()
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 64e363f..a6f6828 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.1.0'
+VERSION_NUM = '0.1.1'
setuptools.setup(
name=MODULE_NAME,
@@ -31,7 +31,6 @@
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
- "typeguard",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"funasr",
diff --git a/funasr/runtime/python/websocket/parse_args.py b/funasr/runtime/python/websocket/parse_args.py
index 82d9c90..ffecff7 100644
--- a/funasr/runtime/python/websocket/parse_args.py
+++ b/funasr/runtime/python/websocket/parse_args.py
@@ -33,7 +33,7 @@
help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu",
type=int,
- default=1,
+ default=4,
help="cpu cores")
parser.add_argument("--certfile",
type=str,
diff --git a/funasr/runtime/python/websocket/wss_client_asr.py b/funasr/runtime/python/websocket/wss_client_asr.py
index 586e0a4..2ea8a16 100644
--- a/funasr/runtime/python/websocket/wss_client_asr.py
+++ b/funasr/runtime/python/websocket/wss_client_asr.py
@@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
import os
import time
-import websockets,ssl
+import websockets, ssl
import asyncio
# import threading
import argparse
@@ -12,6 +12,7 @@
import logging
+SUPPORT_AUDIO_TYPE_SETS = ['.wav', '.pcm']
logging.basicConfig(level=logging.ERROR)
parser = argparse.ArgumentParser()
@@ -53,7 +54,7 @@
type=str,
default=None,
help="output_dir")
-
+
parser.add_argument("--ssl",
type=int,
default=1,
@@ -68,22 +69,25 @@
print(args)
# voices = asyncio.Queue()
from queue import Queue
-voices = Queue()
+voices = Queue()
+offline_msg_done=False
+
ibest_writer = None
if args.output_dir is not None:
writer = DatadirWriter(args.output_dir)
ibest_writer = writer[f"1best_recog"]
+
async def record_microphone():
is_finished = False
import pyaudio
- #print("2")
- global voices
+ # print("2")
+ global voices
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
- chunk_size = 60*args.chunk_size[1]/args.chunk_interval
+ chunk_size = 60 * args.chunk_size[1] / args.chunk_interval
CHUNK = int(RATE / 1000 * chunk_size)
p = pyaudio.PyAudio()
@@ -94,19 +98,16 @@
input=True,
frames_per_buffer=CHUNK)
- message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
+ message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
+ "wav_name": "microphone", "is_speaking": True})
voices.put(message)
while True:
-
data = stream.read(CHUNK)
- message = data
-
+ message = data
voices.put(message)
-
await asyncio.sleep(0.005)
-async def record_from_scp(chunk_begin,chunk_size):
- import wave
+async def record_from_scp(chunk_begin, chunk_size):
global voices
is_finished = False
if args.audio_in.endswith(".scp"):
@@ -114,91 +115,98 @@
wavs = f_scp.readlines()
else:
wavs = [args.audio_in]
- if chunk_size>0:
- wavs=wavs[chunk_begin:chunk_begin+chunk_size]
+ if chunk_size > 0:
+ wavs = wavs[chunk_begin:chunk_begin + chunk_size]
for wav in wavs:
wav_splits = wav.strip().split()
+
wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
-
- # bytes_f = open(wav_path, "rb")
- # bytes_data = bytes_f.read()
- with wave.open(wav_path, "rb") as wav_file:
- params = wav_file.getparams()
- # header_length = wav_file.getheaders()[0][1]
- # wav_file.setpos(header_length)
- frames = wav_file.readframes(wav_file.getnframes())
+ if not len(wav_path.strip())>0:
+ continue
+ if wav_path.endswith(".pcm"):
+ with open(wav_path, "rb") as f:
+ audio_bytes = f.read()
+ elif wav_path.endswith(".wav"):
+ import wave
+ with wave.open(wav_path, "rb") as wav_file:
+ params = wav_file.getparams()
+ frames = wav_file.readframes(wav_file.getnframes())
+ audio_bytes = bytes(frames)
+ else:
+ raise NotImplementedError(
+ f'Not supported audio type')
- audio_bytes = bytes(frames)
# stride = int(args.chunk_size/1000*16000*2)
- stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
- chunk_num = (len(audio_bytes)-1)//stride + 1
+ stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * 16000 * 2)
+ chunk_num = (len(audio_bytes) - 1) // stride + 1
# print(stride)
-
+
# send first time
- message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
- voices.put(message)
+ message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
+ "wav_name": wav_name, "is_speaking": True})
+ #voices.put(message)
+ await websocket.send(message)
is_speaking = True
for i in range(chunk_num):
- beg = i*stride
- data = audio_bytes[beg:beg+stride]
- message = data
- voices.put(message)
- if i == chunk_num-1:
+ beg = i * stride
+ data = audio_bytes[beg:beg + stride]
+ message = data
+ #voices.put(message)
+ await websocket.send(message)
+ if i == chunk_num - 1:
is_speaking = False
message = json.dumps({"is_speaking": is_speaking})
- voices.put(message)
- # print("data_chunk: ", len(data_chunk))
- # print(voices.qsize())
- sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
+ #voices.put(message)
+ await websocket.send(message)
+
+ sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
await asyncio.sleep(sleep_duration)
+ # when all data sent, we need to close websocket
+ while not voices.empty():
+ await asyncio.sleep(1)
+ await asyncio.sleep(3)
+ # offline model need to wait for message recved
+
+ if args.mode=="offline":
+ global offline_msg_done
+ while not offline_msg_done:
+ await asyncio.sleep(1)
+
+ await websocket.close()
+
+
+
-
-async def ws_send():
- global voices
- global websocket
- print("started to sending data!")
- while True:
- while not voices.empty():
- data = voices.get()
- voices.task_done()
- try:
- await websocket.send(data)
- except Exception as e:
- print('Exception occurred:', e)
- traceback.print_exc()
- exit(0)
- await asyncio.sleep(0.005)
- await asyncio.sleep(0.005)
-
-
-
+
+
async def message(id):
- global websocket
+ global websocket,voices,offline_msg_done
text_print = ""
text_print_2pass_online = ""
text_print_2pass_offline = ""
- while True:
- try:
+ try:
+ while True:
+
meg = await websocket.recv()
meg = json.loads(meg)
wav_name = meg.get("wav_name", "demo")
- # print(wav_name)
text = meg["text"]
if ibest_writer is not None:
ibest_writer["text"][wav_name] = text
-
+
if meg["mode"] == "online":
text_print += "{}".format(text)
text_print = text_print[-args.words_max_print:]
os.system('clear')
- print("\rpid"+str(id)+": "+text_print)
- elif meg["mode"] == "online":
+ print("\rpid" + str(id) + ": " + text_print)
+ elif meg["mode"] == "offline":
text_print += "{}".format(text)
text_print = text_print[-args.words_max_print:]
os.system('clear')
- print("\rpid"+str(id)+": "+text_print)
+ print("\rpid" + str(id) + ": " + text_print)
+ offline_msg_done=True
else:
if meg["mode"] == "2pass-online":
text_print_2pass_online += "{}".format(text)
@@ -211,10 +219,12 @@
os.system('clear')
print("\rpid" + str(id) + ": " + text_print)
- except Exception as e:
+ except Exception as e:
print("Exception:", e)
- traceback.print_exc()
- exit(0)
+ #traceback.print_exc()
+ #await websocket.close()
+
+
async def print_messge():
global websocket
@@ -225,72 +235,87 @@
print(meg)
except Exception as e:
print("Exception:", e)
- traceback.print_exc()
+ #traceback.print_exc()
exit(0)
-async def ws_client(id,chunk_begin,chunk_size):
- global websocket
- if args.ssl==1:
- ssl_context = ssl.SSLContext()
- ssl_context.check_hostname = False
- ssl_context.verify_mode = ssl.CERT_NONE
- uri = "wss://{}:{}".format(args.host, args.port)
+async def ws_client(id, chunk_begin, chunk_size):
+ if args.audio_in is None:
+ chunk_begin=0
+ chunk_size=1
+ global websocket,voices,offline_msg_done
+
+ for i in range(chunk_begin,chunk_begin+chunk_size):
+ offline_msg_done=False
+ voices = Queue()
+ if args.ssl == 1:
+ ssl_context = ssl.SSLContext()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ uri = "wss://{}:{}".format(args.host, args.port)
else:
- uri = "ws://{}:{}".format(args.host, args.port)
- ssl_context=None
- print("connect to",uri)
- async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
+ uri = "ws://{}:{}".format(args.host, args.port)
+ ssl_context = None
+ print("connect to", uri)
+ async with websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context) as websocket:
if args.audio_in is not None:
- task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
+ task = asyncio.create_task(record_from_scp(i, 1))
else:
task = asyncio.create_task(record_microphone())
- task2 = asyncio.create_task(ws_send())
- task3 = asyncio.create_task(message(id))
- await asyncio.gather(task, task2, task3)
+ #task2 = asyncio.create_task(ws_send())
+ task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
+ await asyncio.gather(task, task3)
+ exit(0)
+
-def one_thread(id,chunk_begin,chunk_size):
- asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
- asyncio.get_event_loop().run_forever()
-
+def one_thread(id, chunk_begin, chunk_size):
+ asyncio.get_event_loop().run_until_complete(ws_client(id, chunk_begin, chunk_size))
+ asyncio.get_event_loop().run_forever()
if __name__ == '__main__':
- # for microphone
- if args.audio_in is None:
- p = Process(target=one_thread,args=(0, 0, 0))
- p.start()
- p.join()
- print('end')
- else:
- # calculate the number of wavs for each preocess
- if args.audio_in.endswith(".scp"):
- f_scp = open(args.audio_in)
- wavs = f_scp.readlines()
- else:
- wavs = [args.audio_in]
- total_len=len(wavs)
- if total_len>=args.test_thread_num:
- chunk_size=int((total_len)/args.test_thread_num)
- remain_wavs=total_len-chunk_size*args.test_thread_num
- else:
- chunk_size=1
- remain_wavs=0
+ # for microphone
+ if args.audio_in is None:
+ p = Process(target=one_thread, args=(0, 0, 0))
+ p.start()
+ p.join()
+ print('end')
+ else:
+ # calculate the number of wavs for each preocess
+ if args.audio_in.endswith(".scp"):
+ f_scp = open(args.audio_in)
+ wavs = f_scp.readlines()
+ else:
+ wavs = [args.audio_in]
+ for wav in wavs:
+ wav_splits = wav.strip().split()
+ wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
+ wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
+ audio_type = os.path.splitext(wav_path)[-1].lower()
+ if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
+ raise NotImplementedError(
+ f'Not supported audio type: {audio_type}')
- process_list = []
- chunk_begin=0
- for i in range(args.test_thread_num):
- now_chunk_size= chunk_size
- if remain_wavs>0:
- now_chunk_size=chunk_size+1
- remain_wavs=remain_wavs-1
- # process i handle wavs at chunk_begin and size of now_chunk_size
- p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
- chunk_begin=chunk_begin+now_chunk_size
- p.start()
- process_list.append(p)
+ total_len = len(wavs)
+ if total_len >= args.test_thread_num:
+ chunk_size = int(total_len / args.test_thread_num)
+ remain_wavs = total_len - chunk_size * args.test_thread_num
+ else:
+ chunk_size = 1
+ remain_wavs = 0
- for i in process_list:
- p.join()
+ process_list = []
+ chunk_begin = 0
+ for i in range(args.test_thread_num):
+ now_chunk_size = chunk_size
+ if remain_wavs > 0:
+ now_chunk_size = chunk_size + 1
+ remain_wavs = remain_wavs - 1
+ # process i handle wavs at chunk_begin and size of now_chunk_size
+ p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size))
+ chunk_begin = chunk_begin + now_chunk_size
+ p.start()
+ process_list.append(p)
- print('end')
+ for i in process_list:
+ p.join()
-
+ print('end')
diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
index 3810cd6..09f2305 100644
--- a/funasr/runtime/python/websocket/wss_srv_asr.py
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -35,8 +35,6 @@
task=Tasks.voice_activity_detection,
model=args.vad_model,
model_revision=None,
- output_dir=None,
- batch_size=1,
mode='online',
ngpu=args.ngpu,
ncpu=args.ncpu,
@@ -69,9 +67,9 @@
websocket.param_dict_asr_online = {"cache": dict()}
websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
websocket.param_dict_asr_online["is_final"]=True
- audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
- inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
- inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
+ # audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
+ # inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+ # inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
await websocket.close()
diff --git a/funasr/runtime/readme.md b/funasr/runtime/readme.md
new file mode 100644
index 0000000..93a1547
--- /dev/null
+++ b/funasr/runtime/readme.md
@@ -0,0 +1,30 @@
+# FunASR runtime-SDK
+涓枃鏂囨。锛圼鐐瑰嚮姝ゅ](./readme_cn.md)锛�
+
+FunASR is a speech recognition framework developed by the Speech Lab of DAMO Academy, which integrates industrial-level models in the fields of speech endpoint detection, speech recognition, punctuation segmentation, and more.
+It has attracted many developers to participate in experiencing and developing. To solve the last mile of industrial landing and integrate models into business, we have developed the FunASR runtime-SDK. The SDK supports several service deployments, including:
+
+- File transcription service, Mandarin, CPU version, done
+- File transcription service, Mandarin, GPU version, in progress
+- File transcription service, English, in progress
+- Streaming speech recognition service, is in progress
+- and more.
+
+
+## File Transcription Service, Mandarin (CPU)
+
+Currently, the FunASR runtime-SDK-0.0.1 version supports the deployment of file transcription service, Mandarin (CPU version), with a complete speech recognition chain that can transcribe tens of hours of audio into punctuated text, and supports recognition for more than a hundred concurrent streams.
+
+To meet the needs of different users, we have prepared different tutorials with text and images for both novice and advanced developers.
+
+### Technical Principles
+
+The technical principles and documentation behind FunASR explain the underlying technology, recognition accuracy, computational efficiency, and core advantages of the framework, including convenience, high precision, high efficiency, and support for long audio chains. For detailed information, please refer to the documentation available by [docs](https://mp.weixin.qq.com/s?__biz=MzA3MTQ0NTUyMw==&tempkey=MTIyNF84d05USjMxSEpPdk5GZXBJUFNJNzY0bU1DTkxhV19mcWY4MTNWQTJSYXhUaFgxOWFHZTZKR0JzWC1JRmRCdUxCX2NoQXg0TzFpNmVJX2R1WjdrcC02N2FEcUc3MDhzVVhpNWQ5clU4QUdqNFdkdjFYb18xRjlZMmc5c3RDOTl0U0NiRkJLb05ZZ0RmRlVkVjFCZnpXNWFBVlRhbXVtdWs4bUMwSHZnfn4%3D&chksm=1f2c3254285bbb42bc8f76a82e9c5211518a0bb1ff8c357d085c1b78f675ef2311f3be6e282c#rd).
+
+### Deployment Tutorial
+
+The documentation mainly targets novice users who have no need for modifications or customization. It supports downloading model deployments from modelscope and also supports deploying models that users have fine-tuned. For detailed tutorials, please refer to [docs](docs/SDK_tutorial.md).
+
+### Advanced Development Guide
+
+The documentation mainly targets advanced developers who require modifications and customization of the service. It supports downloading model deployments from modelscope and also supports deploying models that users have fine-tuned. For detailed information, please refer to the documentation available by [docs](websocket/readme.md)
diff --git a/funasr/runtime/readme_cn.md b/funasr/runtime/readme_cn.md
new file mode 100644
index 0000000..3a76c08
--- /dev/null
+++ b/funasr/runtime/readme_cn.md
@@ -0,0 +1,31 @@
+# FunASR runtime-SDK
+
+English Version锛圼docs](./readme.md)锛�
+
+FunASR鏄敱杈炬懇闄㈣闊冲疄楠屽寮�婧愮殑涓�娆捐闊宠瘑鍒熀纭�妗嗘灦锛岄泦鎴愪簡璇煶绔偣妫�娴嬨�佽闊宠瘑鍒�佹爣鐐规柇鍙ョ瓑棰嗗煙鐨勫伐涓氱骇鍒ā鍨嬶紝鍚稿紩浜嗕紬澶氬紑鍙戣�呭弬涓庝綋楠屽拰寮�鍙戙�備负浜嗚В鍐冲伐涓氳惤鍦扮殑鏈�鍚庝竴鍏噷锛屽皢妯″瀷闆嗘垚鍒颁笟鍔′腑鍘伙紝鎴戜滑寮�鍙戜簡FunASR runtime-SDK銆�
+SDK 鏀寔浠ヤ笅鍑犵鏈嶅姟閮ㄧ讲锛�
+
+- 涓枃绂荤嚎鏂囦欢杞啓鏈嶅姟锛圕PU鐗堟湰锛夛紝宸插畬鎴�
+- 涓枃绂荤嚎鏂囦欢杞啓鏈嶅姟锛圙PU鐗堟湰锛夛紝杩涜涓�
+- 鑻辨枃绂荤嚎杞啓鏈嶅姟锛岃繘琛屼腑
+- 娴佸紡璇煶璇嗗埆鏈嶅姟锛岃繘琛屼腑
+- 銆傘�傘��
+
+
+## 涓枃绂荤嚎鏂囦欢杞啓鏈嶅姟閮ㄧ讲锛圕PU鐗堟湰锛�
+
+鐩墠FunASR runtime-SDK-0.0.1鐗堟湰宸叉敮鎸佷腑鏂囪闊崇绾挎枃浠舵湇鍔¢儴缃诧紙CPU鐗堟湰锛夛紝鎷ユ湁瀹屾暣鐨勮闊宠瘑鍒摼璺紝鍙互灏嗗嚑鍗佷釜灏忔椂鐨勯煶棰戣瘑鍒垚甯︽爣鐐圭殑鏂囧瓧锛岃�屼笖鏀寔涓婄櫨璺苟鍙戝悓鏃惰繘琛岃瘑鍒��
+
+涓轰簡鏀寔涓嶅悓鐢ㄦ埛鐨勯渶姹傦紝鎴戜滑鍒嗗埆閽堝灏忕櫧涓庨珮闃跺紑鍙戣�咃紝鍑嗗浜嗕笉鍚岀殑鍥炬枃鏁欑▼锛�
+
+### 鎶�鏈師鐞嗘彮绉�
+
+鏂囨。浠嬬粛浜嗚儗鍚庢妧鏈師鐞嗭紝璇嗗埆鍑嗙‘鐜囷紝璁$畻鏁堢巼绛夛紝浠ュ強鏍稿績浼樺娍浠嬬粛锛氫究鎹枫�侀珮绮惧害銆侀珮鏁堢巼銆侀暱闊抽閾捐矾锛岃缁嗘枃妗e弬鑰冿紙[鐐瑰嚮姝ゅ](https://mp.weixin.qq.com/s?__biz=MzA3MTQ0NTUyMw==&tempkey=MTIyNF84d05USjMxSEpPdk5GZXBJUFNJNzY0bU1DTkxhV19mcWY4MTNWQTJSYXhUaFgxOWFHZTZKR0JzWC1JRmRCdUxCX2NoQXg0TzFpNmVJX2R1WjdrcC02N2FEcUc3MDhzVVhpNWQ5clU4QUdqNFdkdjFYb18xRjlZMmc5c3RDOTl0U0NiRkJLb05ZZ0RmRlVkVjFCZnpXNWFBVlRhbXVtdWs4bUMwSHZnfn4%3D&chksm=1f2c3254285bbb42bc8f76a82e9c5211518a0bb1ff8c357d085c1b78f675ef2311f3be6e282c#rd)锛�
+
+### 渚挎嵎閮ㄧ讲鏁欑▼
+
+鏂囨。涓昏閽堝灏忕櫧鐢ㄦ埛涓庡垵绾у紑鍙戣�咃紝娌℃湁淇敼銆佸畾鍒堕渶姹傦紝鏀寔浠巑odelscope涓笅杞芥ā鍨嬮儴缃诧紝涔熸敮鎸佺敤鎴穎inetune鍚庣殑妯″瀷閮ㄧ讲锛岃缁嗘暀绋嬪弬鑰冿紙[鐐瑰嚮姝ゅ](./docs/SDK_tutorial_cn.md)锛�
+
+### 楂橀樁寮�鍙戞寚鍗�
+
+鏂囨。涓昏閽堝楂橀樁寮�鍙戣�咃紝闇�瑕佸鏈嶅姟杩涜淇敼涓庡畾鍒讹紝鏀寔浠巑odelscope涓笅杞芥ā鍨嬮儴缃诧紝涔熸敮鎸佺敤鎴穎inetune鍚庣殑妯″瀷閮ㄧ讲锛岃缁嗘枃妗e弬鑰冿紙[鐐瑰嚮姝ゅ](./docs/SDK_advanced_guide_cn.md)锛�
diff --git a/funasr/runtime/ssl_key/readme.md b/funasr/runtime/ssl_key/readme.md
index a5989e6..8a48dd3 100644
--- a/funasr/runtime/ssl_key/readme.md
+++ b/funasr/runtime/ssl_key/readme.md
@@ -3,7 +3,7 @@
```shell
### 1) Generate a private key
-openssl genrsa -des3 -out server.key 1024
+openssl genrsa -des3 -out server.key 2048
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
@@ -14,4 +14,4 @@
### 4) Generated a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
-```
\ No newline at end of file
+```
diff --git a/funasr/runtime/ssl_key/server.crt b/funasr/runtime/ssl_key/server.crt
index 808b73e..5a5079d 100644
--- a/funasr/runtime/ssl_key/server.crt
+++ b/funasr/runtime/ssl_key/server.crt
@@ -1,15 +1,21 @@
-----BEGIN CERTIFICATE-----
-MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
-CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
-MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
-YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
-TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
-YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
-SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
-NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
-38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
-mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
-dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
-uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
-hXtozVCgkSJzX6uD
+MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
+CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
+MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
+bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
+DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
+EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
+aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
+ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
+qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
+f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
+EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
+Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
+LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
+AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
+D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
+dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
+Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
+7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
+Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
-----END CERTIFICATE-----
diff --git a/funasr/runtime/ssl_key/server.key b/funasr/runtime/ssl_key/server.key
index aac8b26..8efdcb8 100644
--- a/funasr/runtime/ssl_key/server.key
+++ b/funasr/runtime/ssl_key/server.key
@@ -1,15 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
-MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u
-cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso
-OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB
-AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq
-pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc
-2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE
-NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD
-M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3
-7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt
-w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z
-fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l
-SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO
-uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB
+MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
+vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
+5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
+jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
+ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
+NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
+MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
+vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
+jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
+eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
+C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
+6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
+jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
+qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
+aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
+zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
+xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
+9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
+3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
+P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
+7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
+cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
+dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
+cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
+W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
-----END RSA PRIVATE KEY-----
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
index 2f84bb8..c556daf 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
@@ -109,7 +109,6 @@
lfr_n: int = 6,
dither: float = 1.0
) -> None:
- # check_argument_types()
self.fs = fs
self.window = window
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index 58ca972..513e48d 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -6,12 +6,10 @@
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
if(ENABLE_WEBSOCKET)
# cmake_policy(SET CMP0135 NEW)
-
include(FetchContent)
FetchContent_Declare(websocketpp
GIT_REPOSITORY https://github.com/zaphoyd/websocketpp.git
@@ -22,7 +20,6 @@
FetchContent_MakeAvailable(websocketpp)
include_directories(${PROJECT_SOURCE_DIR}/third_party/websocket)
-
FetchContent_Declare(asio
URL https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
@@ -38,8 +35,6 @@
FetchContent_MakeAvailable(json)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
-
-
endif()
@@ -61,8 +56,8 @@
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
-add_executable(websocketmain "websocketmain.cpp" "websocketsrv.cpp")
-add_executable(websocketclient "websocketclient.cpp")
+add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
+add_executable(funasr-wss-client "funasr-wss-client.cpp")
-target_link_libraries(websocketclient PUBLIC funasr ssl crypto)
-target_link_libraries(websocketmain PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
new file mode 100644
index 0000000..eb94d14
--- /dev/null
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -0,0 +1,378 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// client for websocket, support multiple threads
+// ./funasr-wss-client --server-ip <string>
+// --port <string>
+// --wav-path <string>
+// [--thread-num <int>]
+// [--is-ssl <int>] [--]
+// [--version] [-h]
+// example:
+// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
+
+#define ASIO_STANDALONE 1
+#include <websocketpp/client.hpp>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio_client.hpp>
+#include <fstream>
+#include <atomic>
+#include <glog/logging.h>
+
+#include "audio.h"
+#include "nlohmann/json.hpp"
+#include "tclap/CmdLine.h"
+
+/**
+ * Define a semi-cross platform helper method that waits/sleeps for a bit.
+ */
+void WaitABit() {
+ #ifdef WIN32
+ Sleep(1000);
+ #else
+ sleep(1);
+ #endif
+}
+std::atomic<int> wav_index(0);
+
+bool IsTargetFile(const std::string& filename, const std::string target) {
+ std::size_t pos = filename.find_last_of(".");
+ if (pos == std::string::npos) {
+ return false;
+ }
+ std::string extension = filename.substr(pos + 1);
+ return (extension == target);
+}
+
+typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context> context_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+context_ptr OnTlsInit(websocketpp::connection_hdl) {
+ context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+ asio::ssl::context::sslv23);
+
+ try {
+ ctx->set_options(
+ asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
+ asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
+
+ } catch (std::exception& e) {
+ LOG(ERROR) << e.what();
+ }
+ return ctx;
+}
+
+// template for tls or not config
+template <typename T>
+class WebsocketClient {
+ public:
+ // typedef websocketpp::client<T> client;
+ // typedef websocketpp::client<websocketpp::config::asio_tls_client>
+ // wss_client;
+ typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
+
+ WebsocketClient(int is_ssl) : m_open(false), m_done(false) {
+ // set up access channels to only log interesting things
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ m_client.set_access_channels(websocketpp::log::alevel::connect);
+ m_client.set_access_channels(websocketpp::log::alevel::disconnect);
+ m_client.set_access_channels(websocketpp::log::alevel::app);
+
+ // Initialize the Asio transport policy
+ m_client.init_asio();
+
+ // Bind the handlers we are using
+ using websocketpp::lib::bind;
+ using websocketpp::lib::placeholders::_1;
+ m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
+ m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
+
+ m_client.set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+
+ m_client.set_fail_handler(bind(&WebsocketClient::on_fail, this, _1));
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ }
+
+ void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
+ const std::string& payload = msg->get_payload();
+ switch (msg->get_opcode()) {
+ case websocketpp::frame::opcode::text:
+ total_num=total_num+1;
+ LOG(INFO)<<total_num<<",on_message = " << payload;
+ if((total_num+1)==wav_index)
+ {
+ websocketpp::lib::error_code ec;
+ m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
+ if (ec){
+ LOG(ERROR)<< "Error closing connection " << ec.message();
+ }
+ }
+ }
+ }
+
+ // This method will block until the connection is complete
+ void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+ // Create a new connection to the given URI
+ websocketpp::lib::error_code ec;
+ typename websocketpp::client<T>::connection_ptr con =
+ m_client.get_connection(uri, ec);
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Get Connection Error: " + ec.message());
+ return;
+ }
+ // Grab a handle for this connection so we can talk to it in a thread
+ // safe manor after the event loop starts.
+ m_hdl = con->get_handle();
+
+ // Queue the connection. No DNS queries or network connections will be
+ // made until the io_service event loop is run.
+ m_client.connect(con);
+
+ // Create a thread to run the ASIO io_service event loop
+ websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
+ &m_client);
+ while(true){
+ int i = wav_index.fetch_add(1);
+ if (i >= wav_list.size()) {
+ break;
+ }
+ send_wav_data(wav_list[i], wav_ids[i]);
+ }
+ WaitABit();
+
+ asio_thread.join();
+
+ }
+
+ // The open handler will signal that we are ready to start sending data
+ void on_open(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection opened, starting data!");
+
+ scoped_lock guard(m_lock);
+ m_open = true;
+ }
+
+ // The close handler will signal that we should stop sending data
+ void on_close(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection closed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+
+ // The fail handler will signal that we should stop sending data
+ void on_fail(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection failed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+ // send wav to server
+ void send_wav_data(string wav_path, string wav_id) {
+ uint64_t count = 0;
+ std::stringstream val;
+
+ funasr::Audio audio(1);
+ int32_t sampling_rate = 16000;
+ if(IsTargetFile(wav_path.c_str(), "wav")){
+ int32_t sampling_rate = -1;
+ if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
+ return ;
+ }else if(IsTargetFile(wav_path.c_str(), "pcm")){
+ if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
+ return ;
+ }else{
+ printf("Wrong wav extension");
+ exit(-1);
+ }
+
+ float* buff;
+ int len;
+ int flag = 0;
+ bool wait = false;
+ while (1) {
+ {
+ scoped_lock guard(m_lock);
+ // If the connection has been closed, stop generating data
+ if (m_done) {
+ break;
+ }
+ // If the connection hasn't been opened yet wait a bit and retry
+ if (!m_open) {
+ wait = true;
+ } else {
+ break;
+ }
+ }
+ if (wait) {
+ // LOG(INFO) << "wait.." << m_open;
+ WaitABit();
+ continue;
+ }
+ }
+ websocketpp::lib::error_code ec;
+
+ nlohmann::json jsonbegin;
+ nlohmann::json chunk_size = nlohmann::json::array();
+ chunk_size.push_back(5);
+ chunk_size.push_back(0);
+ chunk_size.push_back(5);
+ jsonbegin["chunk_size"] = chunk_size;
+ jsonbegin["chunk_interval"] = 10;
+ jsonbegin["wav_name"] = wav_id;
+ jsonbegin["is_speaking"] = true;
+ m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+ ec);
+
+ // fetch wav data use asr engine api
+ while (audio.Fetch(buff, len, flag) > 0) {
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buff[i]*32768);
+ }
+
+ // send data to server
+ int offset = 0;
+ int block_size = 102400;
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO) << "sended data len=" << len * sizeof(short);
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ break;
+ }
+ delete[] iArray;
+ // WaitABit();
+ }
+ nlohmann::json jsonresult;
+ jsonresult["is_speaking"] = false;
+ m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+ // WaitABit();
+ }
+ websocketpp::client<T> m_client;
+
+ private:
+ websocketpp::connection_hdl m_hdl;
+ websocketpp::lib::mutex m_lock;
+ bool m_open;
+ bool m_done;
+ int total_num=0;
+};
+
+int main(int argc, char* argv[]) {
+
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
+ TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
+ "127.0.0.1", "string");
+ TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
+ TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
+ "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
+ true, "", "string");
+ TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
+ false, 1, "int");
+ TCLAP::ValueArg<int> is_ssl_(
+ "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
+ false, 1, "int");
+
+ cmd.add(server_ip_);
+ cmd.add(port_);
+ cmd.add(wav_path_);
+ cmd.add(thread_num_);
+ cmd.add(is_ssl_);
+ cmd.parse(argc, argv);
+
+ std::string server_ip = server_ip_.getValue();
+ std::string port = port_.getValue();
+ std::string wav_path = wav_path_.getValue();
+ int threads_num = thread_num_.getValue();
+ int is_ssl = is_ssl_.getValue();
+
+ std::vector<websocketpp::lib::thread> client_threads;
+ std::string uri = "";
+ if (is_ssl == 1) {
+ uri = "wss://" + server_ip + ":" + port;
+ } else {
+ uri = "ws://" + server_ip + ":" + port;
+ }
+
+ // read wav_path
+ std::vector<string> wav_list;
+ std::vector<string> wav_ids;
+ string default_id = "wav_default_id";
+ if(IsTargetFile(wav_path, "wav") || IsTargetFile(wav_path, "pcm")){
+ wav_list.emplace_back(wav_path);
+ wav_ids.emplace_back(default_id);
+ }
+ else if(IsTargetFile(wav_path, "scp")){
+ ifstream in(wav_path);
+ if (!in.is_open()) {
+ printf("Failed to open scp file");
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
+ }
+ in.close();
+ }else{
+ printf("Please check the wav extension!");
+ exit(-1);
+ }
+
+ for (size_t i = 0; i < threads_num; i++) {
+ client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+ if (is_ssl == 1) {
+ WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+
+ c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+
+ c.run(uri, wav_list, wav_ids);
+ } else {
+ WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+
+ c.run(uri, wav_list, wav_ids);
+ }
+ });
+ }
+
+ for (auto& t : client_threads) {
+ t.join();
+ }
+}
\ No newline at end of file
diff --git a/funasr/runtime/websocket/funasr-wss-server.cpp b/funasr/runtime/websocket/funasr-wss-server.cpp
new file mode 100644
index 0000000..5061bba
--- /dev/null
+++ b/funasr/runtime/websocket/funasr-wss-server.cpp
@@ -0,0 +1,329 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// io server
+// Usage:funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num <int>]
+// [--io_thread_num <int>] [--port <int>] [--listen_ip
+// <string>] [--punc-quant <string>] [--punc-dir <string>]
+// [--vad-quant <string>] [--vad-dir <string>] [--quantize
+// <string>] --model-dir <string> [--] [--version] [-h]
+#include "websocket-server.h"
+#include <unistd.h>
+
+using namespace std;
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
+ std::map<std::string, std::string>& model_path) {
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO) << key << " : " << value_arg.getValue();
+}
+int main(int argc, char* argv[]) {
+ try {
+
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
+ TCLAP::ValueArg<std::string> download_model_dir(
+ "", "download-model-dir",
+ "Download model from Modelscope to download_model_dir",
+ false, "/workspace/models", "string");
+ TCLAP::ValueArg<std::string> model_dir(
+ "", MODEL_DIR,
+ "default: /workspace/models/asr, the asr model path, which contains model_quant.onnx, config.yaml, am.mvn",
+ false, "/workspace/models/asr", "string");
+ TCLAP::ValueArg<std::string> model_revision(
+ "", "model-revision",
+ "ASR model revision",
+ false, "v1.2.1", "string");
+ TCLAP::ValueArg<std::string> quantize(
+ "", QUANTIZE,
+ "true (Default), load the model of model_quant.onnx in model_dir. If set "
+ "false, load the model of model.onnx in model_dir",
+ false, "true", "string");
+ TCLAP::ValueArg<std::string> vad_dir(
+ "", VAD_DIR,
+ "default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
+ false, "/workspace/models/vad", "string");
+ TCLAP::ValueArg<std::string> vad_revision(
+ "", "vad-revision",
+ "VAD model revision",
+ false, "v1.2.0", "string");
+ TCLAP::ValueArg<std::string> vad_quant(
+ "", VAD_QUANT,
+ "true (Default), load the model of model_quant.onnx in vad_dir. If set "
+ "false, load the model of model.onnx in vad_dir",
+ false, "true", "string");
+ TCLAP::ValueArg<std::string> punc_dir(
+ "", PUNC_DIR,
+ "default: /workspace/models/punc, the punc model path, which contains model_quant.onnx, punc.yaml",
+ false, "/workspace/models/punc",
+ "string");
+ TCLAP::ValueArg<std::string> punc_revision(
+ "", "punc-revision",
+ "PUNC model revision",
+ false, "v1.1.7", "string");
+ TCLAP::ValueArg<std::string> punc_quant(
+ "", PUNC_QUANT,
+ "true (Default), load the model of model_quant.onnx in punc_dir. If set "
+ "false, load the model of model.onnx in punc_dir",
+ false, "true", "string");
+
+ TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
+ "0.0.0.0", "string");
+ TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int");
+ TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
+ false, 8, "int");
+ TCLAP::ValueArg<int> decoder_thread_num(
+ "", "decoder-thread-num", "decoder thread num", false, 8, "int");
+ TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
+ "model thread num", false, 1, "int");
+
+ TCLAP::ValueArg<std::string> certfile("", "certfile",
+ "default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",
+ false, "../../../ssl_key/server.crt", "string");
+ TCLAP::ValueArg<std::string> keyfile("", "keyfile",
+ "default: ../../../ssl_key/server.key, path of keyfile for WSS connection",
+ false, "../../../ssl_key/server.key", "string");
+
+ cmd.add(certfile);
+ cmd.add(keyfile);
+
+ cmd.add(download_model_dir);
+ cmd.add(model_dir);
+ cmd.add(model_revision);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_revision);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_revision);
+ cmd.add(punc_quant);
+
+ cmd.add(listen_ip);
+ cmd.add(port);
+ cmd.add(io_thread_num);
+ cmd.add(decoder_thread_num);
+ cmd.add(model_thread_num);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(model_dir, MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(vad_dir, VAD_DIR, model_path);
+ GetValue(vad_quant, VAD_QUANT, model_path);
+ GetValue(punc_dir, PUNC_DIR, model_path);
+ GetValue(punc_quant, PUNC_QUANT, model_path);
+
+ GetValue(model_revision, "model-revision", model_path);
+ GetValue(vad_revision, "vad-revision", model_path);
+ GetValue(punc_revision, "punc-revision", model_path);
+
+ // Download model form Modelscope
+ try{
+ std::string s_download_model_dir = download_model_dir.getValue();
+
+ std::string s_vad_path = model_path[VAD_DIR];
+ std::string s_vad_quant = model_path[VAD_QUANT];
+ std::string s_asr_path = model_path[MODEL_DIR];
+ std::string s_asr_quant = model_path[QUANTIZE];
+ std::string s_punc_path = model_path[PUNC_DIR];
+ std::string s_punc_quant = model_path[PUNC_QUANT];
+
+ std::string python_cmd = "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
+
+ if(vad_dir.isSet() && !s_vad_path.empty()){
+ std::string python_cmd_vad;
+ std::string down_vad_path;
+ std::string down_vad_model;
+
+ if (access(s_vad_path.c_str(), F_OK) == 0){
+ // local
+ python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir ./ " + " --model_revision " + model_path["vad-revision"];
+ down_vad_path = s_vad_path;
+ }else{
+ // modelscope
+ LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
+ python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["vad-revision"];
+ down_vad_path = s_download_model_dir+"/"+s_vad_path;
+ }
+
+ int ret = system(python_cmd_vad.c_str());
+ if(ret !=0){
+ LOG(INFO) << "Failed to download model from modelscope. If you set local vad model path, you can ignore the errors.";
+ }
+ down_vad_model = down_vad_path+"/model_quant.onnx";
+ if(s_vad_quant=="false" || s_vad_quant=="False" || s_vad_quant=="FALSE"){
+ down_vad_model = down_vad_path+"/model.onnx";
+ }
+
+ if (access(down_vad_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_vad_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[VAD_DIR]=down_vad_path;
+ LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
+ }
+ }else{
+ LOG(INFO) << "VAD model is not set, use default.";
+ }
+
+ if(model_dir.isSet() && !s_asr_path.empty()){
+ std::string python_cmd_asr;
+ std::string down_asr_path;
+ std::string down_asr_model;
+
+ if (access(s_asr_path.c_str(), F_OK) == 0){
+ // local
+ python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
+ down_asr_path = s_asr_path;
+ }else{
+ // modelscope
+ LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
+ python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
+ down_asr_path = s_download_model_dir+"/"+s_asr_path;
+ }
+
+ int ret = system(python_cmd_asr.c_str());
+ if(ret !=0){
+ LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
+ }
+ down_asr_model = down_asr_path+"/model_quant.onnx";
+ if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
+ down_asr_model = down_asr_path+"/model.onnx";
+ }
+
+ if (access(down_asr_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_asr_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[MODEL_DIR]=down_asr_path;
+ LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
+ }
+ }else{
+ LOG(INFO) << "ASR model is not set, use default.";
+ }
+
+ if(punc_dir.isSet() && !s_punc_path.empty()){
+ std::string python_cmd_punc;
+ std::string down_punc_path;
+ std::string down_punc_model;
+
+ if (access(s_punc_path.c_str(), F_OK) == 0){
+ // local
+ python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir ./ " + " --model_revision " + model_path["punc-revision"];
+ down_punc_path = s_punc_path;
+ }else{
+ // modelscope
+ LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
+ python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["punc-revision"];
+ down_punc_path = s_download_model_dir+"/"+s_punc_path;
+ }
+
+ int ret = system(python_cmd_punc.c_str());
+ if(ret !=0){
+ LOG(INFO) << "Failed to download model from modelscope. If you set local punc model path, you can ignore the errors.";
+ }
+ down_punc_model = down_punc_path+"/model_quant.onnx";
+ if(s_punc_quant=="false" || s_punc_quant=="False" || s_punc_quant=="FALSE"){
+ down_punc_model = down_punc_path+"/model.onnx";
+ }
+
+ if (access(down_punc_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_punc_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[PUNC_DIR]=down_punc_path;
+ LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
+ }
+ }else{
+ LOG(INFO) << "PUNC model is not set, use default.";
+ }
+
+ } catch (std::exception const& e) {
+ LOG(ERROR) << "Error: " << e.what();
+ }
+
+ std::string s_listen_ip = listen_ip.getValue();
+ int s_port = port.getValue();
+ int s_io_thread_num = io_thread_num.getValue();
+ int s_decoder_thread_num = decoder_thread_num.getValue();
+
+ int s_model_thread_num = model_thread_num.getValue();
+
+ asio::io_context io_decoder; // context for decoding
+ asio::io_context io_server; // context for server
+
+ std::vector<std::thread> decoder_threads;
+
+ std::string s_certfile = certfile.getValue();
+ std::string s_keyfile = keyfile.getValue();
+
+ bool is_ssl = false;
+ if (!s_certfile.empty()) {
+ is_ssl = true;
+ }
+
+ auto conn_guard = asio::make_work_guard(
+ io_decoder); // make sure threads can wait in the queue
+ auto server_guard = asio::make_work_guard(
+ io_server); // make sure threads can wait in the queue
+ // create threads pool
+ for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
+ decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
+ }
+
+ server server_; // server for websocket
+ wss_server wss_server_;
+ if (is_ssl) {
+ wss_server_.init_asio(&io_server); // init asio
+ wss_server_.set_reuse_addr(
+ true); // reuse address as we create multiple threads
+
+ // list on port for accept
+ wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+ WebSocketServer websocket_srv(
+ io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
+ s_keyfile); // websocket server for asr engine
+ websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+
+ } else {
+ server_.init_asio(&io_server); // init asio
+ server_.set_reuse_addr(
+ true); // reuse address as we create multiple threads
+
+ // list on port for accept
+ server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+ WebSocketServer websocket_srv(
+ io_decoder, is_ssl, &server_, nullptr, s_certfile,
+ s_keyfile); // websocket server for asr engine
+ websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+ }
+
+ std::cout << "asr model init finished. listen on port:" << s_port
+ << std::endl;
+
+ // Start the ASIO network io_service run loop
+ std::vector<std::thread> ts;
+ // create threads for io network
+ for (size_t i = 0; i < s_io_thread_num; i++) {
+ ts.emplace_back([&io_server]() { io_server.run(); });
+ }
+ // wait for theads
+ for (size_t i = 0; i < s_io_thread_num; i++) {
+ ts[i].join();
+ }
+
+ // wait for theads
+ for (auto& t : decoder_threads) {
+ t.join();
+ }
+
+ } catch (std::exception const& e) {
+ std::cerr << "Error: " << e.what() << std::endl;
+ }
+
+ return 0;
+}
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index d2a54e9..b67a905 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -5,15 +5,21 @@
```shell
# pip3 install torch torchaudio
-pip install -U modelscope funasr
+pip3 install -U modelscope funasr
# For the users in China, you could install with the command:
-# pip install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
+python -m funasr.export.export_model \
+--export-dir ./export \
+--type onnx \
+--quantize True \
+--model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
+--model-name damo/speech_fsmn_vad_zh-cn-16k-common-pytorch \
+--model-name damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch
```
## Building for Linux/Unix
@@ -36,10 +42,11 @@
required openssl lib
```shell
-#install openssl lib first
-apt-get install libssl-dev
+apt-get install libssl-dev #ubuntu
+# yum install openssl-devel #centos
-git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/websocket
+
+git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/websocket
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
make
@@ -48,56 +55,89 @@
```shell
cd bin
- ./websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
- [--io_thread_num <int>] [--port <int>] [--listen_ip
+./funasr-wss-server [--download-model-dir <string>]
+ [--model-thread-num <int>] [--decoder-thread-num <int>]
+ [--io-thread-num <int>] [--port <int>] [--listen_ip
<string>] [--punc-quant <string>] [--punc-dir <string>]
[--vad-quant <string>] [--vad-dir <string>] [--quantize
<string>] --model-dir <string> [--keyfile <string>]
[--certfile <string>] [--] [--version] [-h]
Where:
+ --download-model-dir <string>
+ Download model from Modelscope to download_model_dir
+
--model-dir <string>
- (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
+ default: /workspace/models/asr, the asr model path, which contains model_quant.onnx, config.yaml, am.mvn
--quantize <string>
- false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+ true (Default), load the model of model_quant.onnx in model_dir. If set false, load the model of model.onnx in model_dir
--vad-dir <string>
- the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn
--vad-quant <string>
- false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+ true (Default), load the model of model_quant.onnx in vad_dir. If set false, load the model of model.onnx in vad_dir
--punc-dir <string>
- the punc model path, which contains model.onnx, punc.yaml
+ default: /workspace/models/punc, the punc model path, which contains model_quant.onnx, punc.yaml
--punc-quant <string>
- false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+ true (Default), load the model of model_quant.onnx in punc_dir. If set false, load the model of model.onnx in punc_dir
- --decoder_thread_num <int>
+ --decoder-thread-num <int>
number of threads for decoder, default:8
- --io_thread_num <int>
+ --io-thread-num <int>
number of threads for network io, default:8
--port <int>
- listen port, default:8889
+ listen port, default:10095
--certfile <string>
- path of certficate for WSS connection. if it is empty, it will be in WS mode.
+ default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.
--keyfile <string>
- path of keyfile for WSS connection
+ default: ../../../ssl_key/server.key, path of keyfile for WSS connection
- Required: --model-dir <string>
- If use vad, please add: --vad-dir <string>
- If use punc, please add: --punc-dir <string>
example:
- websocketmain --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+# you can use models downloaded from modelscope or local models:
+# download models from modelscope
+./funasr-wss-server \
+ --download-model-dir /workspace/models \
+ --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+ --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+ --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+
+# load models from local paths
+./funasr-wss-server \
+ --model-dir /workspace/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+ --vad-dir /workspace/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+ --punc-dir /workspace/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
+
```
## Run websocket client test
```shell
-Usage: ./websocketclient server_ip port wav_path threads_num is_ssl
+./funasr-wss-client --server-ip <string>
+ --port <string>
+ --wav-path <string>
+ [--thread-num <int>]
+ [--is-ssl <int>] [--]
+ [--version] [-h]
-is_ssl is 1 means use wss connection, or use ws connection
+Where:
+ --server-ip <string>
+ (required) server-ip
+
+ --port <string>
+ (required) port
+
+ --wav-path <string>
+ (required) the input could be: wav_path, e.g.: asr_example.wav;
+ pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)
+
+ --thread-num <int>
+ thread-num
+
+ --is-ssl <int>
+ is-ssl is 1 means use wss connection, or use ws connection
example:
-
-websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64 0
+./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
result json, example like:
{"mode":"offline","text":"娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�","wav_name":"wav2"}
diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocket-server.cpp
similarity index 90%
rename from funasr/runtime/websocket/websocketsrv.cpp
rename to funasr/runtime/websocket/websocket-server.cpp
index eb3c8db..a311c23 100644
--- a/funasr/runtime/websocket/websocketsrv.cpp
+++ b/funasr/runtime/websocket/websocket-server.cpp
@@ -10,7 +10,7 @@
// pools, one for handle network data and one for asr decoder.
// now only support offline engine.
-#include "websocketsrv.h"
+#include "websocket-server.h"
#include <thread>
#include <utility>
@@ -22,12 +22,11 @@
std::string& s_keyfile) {
namespace asio = websocketpp::lib::asio;
- std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
- std::cout << "using TLS mode: "
+ LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
+ LOG(INFO) << "using TLS mode: "
<< (mode == MOZILLA_MODERN ? "Mozilla Modern"
- : "Mozilla Intermediate")
- << std::endl;
-
+ : "Mozilla Intermediate");
+
context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
asio::ssl::context::sslv23);
@@ -49,7 +48,7 @@
ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
} catch (std::exception& e) {
- std::cout << "Exception: " << e.what() << std::endl;
+ LOG(INFO) << "Exception: " << e.what();
}
return ctx;
}
@@ -86,8 +85,7 @@
ec);
}
- std::cout << "buffer.size=" << buffer.size()
- << ",result json=" << jsonresult.dump() << std::endl;
+ LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
if (!isonline) {
// close the client if it is not online asr
// server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
@@ -110,14 +108,14 @@
data_msg->samples = std::make_shared<std::vector<char>>();
data_msg->msg = nlohmann::json::parse("{}");
data_map.emplace(hdl, data_msg);
- std::cout << "on_open, active connections: " << data_map.size() << std::endl;
+ LOG(INFO) << "on_open, active connections: " << data_map.size();
}
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock);
data_map.erase(hdl); // remove data vector when connection is closed
- std::cout << "on_close, active connections: " << data_map.size() << std::endl;
+ LOG(INFO) << "on_close, active connections: " << data_map.size();
}
// remove closed connection
@@ -143,7 +141,7 @@
}
for (auto hdl : to_remove) {
data_map.erase(hdl);
- std::cout << "remove one connection " << std::endl;
+ LOG(INFO)<< "remove one connection ";
}
}
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
@@ -161,7 +159,7 @@
lock.unlock();
if (sample_data_p == nullptr) {
- std::cout << "error when fetch sample data vector" << std::endl;
+ LOG(INFO) << "error when fetch sample data vector";
return;
}
@@ -176,7 +174,7 @@
if (jsonresult["is_speaking"] == false ||
jsonresult["is_finished"] == true) {
- std::cout << "client done" << std::endl;
+ LOG(INFO) << "client done";
if (isonline) {
// do_close(ws);
@@ -225,9 +223,9 @@
// init model with api
asr_hanlde = FunOfflineInit(model_path, thread_num);
- std::cout << "model ready" << std::endl;
+ LOG(INFO) << "model successfully inited";
} catch (const std::exception& e) {
- std::cout << e.what() << std::endl;
+ LOG(INFO) << e.what();
}
}
diff --git a/funasr/runtime/websocket/websocketsrv.h b/funasr/runtime/websocket/websocket-server.h
similarity index 97%
rename from funasr/runtime/websocket/websocketsrv.h
rename to funasr/runtime/websocket/websocket-server.h
index 3cb8816..198af1c 100644
--- a/funasr/runtime/websocket/websocketsrv.h
+++ b/funasr/runtime/websocket/websocket-server.h
@@ -10,8 +10,8 @@
// pools, one for handle network data and one for asr decoder.
// now only support offline engine.
-#ifndef WEBSOCKETSRV_SERVER_H_
-#define WEBSOCKETSRV_SERVER_H_
+#ifndef WEBSOCKET_SERVER_H_
+#define WEBSOCKET_SERVER_H_
#include <iostream>
#include <map>
@@ -134,4 +134,4 @@
websocketpp::lib::mutex m_lock; // mutex for sample_map
};
-#endif // WEBSOCKETSRV_SERVER_H_
+#endif // WEBSOCKET_SERVER_H_
diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
deleted file mode 100644
index e9f8f1d..0000000
--- a/funasr/runtime/websocket/websocketclient.cpp
+++ /dev/null
@@ -1,277 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
- * Reserved. MIT License (https://opensource.org/licenses/MIT)
- */
-/* 2022-2023 by zhaomingwork */
-
-// client for websocket, support multiple threads
-// Usage: websocketclient server_ip port wav_path threads_num
-
-#define ASIO_STANDALONE 1
-#include <websocketpp/client.hpp>
-#include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_client.hpp>
-
-#include "audio.h"
-#include "nlohmann/json.hpp"
-
-/**
- * Define a semi-cross platform helper method that waits/sleeps for a bit.
- */
-void wait_a_bit() {
-#ifdef WIN32
- Sleep(1000);
-#else
- sleep(1);
-#endif
-}
-typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
-typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
- context_ptr;
-using websocketpp::lib::bind;
-using websocketpp::lib::placeholders::_1;
-using websocketpp::lib::placeholders::_2;
-context_ptr on_tls_init(websocketpp::connection_hdl) {
- context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
- asio::ssl::context::sslv23);
-
- try {
- ctx->set_options(
- asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
- asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
-
- } catch (std::exception& e) {
- std::cout << e.what() << std::endl;
- }
- return ctx;
-}
-// template for tls or not config
-template <typename T>
-class websocket_client {
- public:
- // typedef websocketpp::client<T> client;
- // typedef websocketpp::client<websocketpp::config::asio_tls_client>
- // wss_client;
- typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
-
- websocket_client(int is_ssl) : m_open(false), m_done(false) {
- // set up access channels to only log interesting things
-
- m_client.clear_access_channels(websocketpp::log::alevel::all);
- m_client.set_access_channels(websocketpp::log::alevel::connect);
- m_client.set_access_channels(websocketpp::log::alevel::disconnect);
- m_client.set_access_channels(websocketpp::log::alevel::app);
-
- // Initialize the Asio transport policy
- m_client.init_asio();
-
- // Bind the handlers we are using
- using websocketpp::lib::bind;
- using websocketpp::lib::placeholders::_1;
- m_client.set_open_handler(bind(&websocket_client::on_open, this, _1));
- m_client.set_close_handler(bind(&websocket_client::on_close, this, _1));
- m_client.set_close_handler(bind(&websocket_client::on_close, this, _1));
-
- m_client.set_message_handler(
- [this](websocketpp::connection_hdl hdl, message_ptr msg) {
- on_message(hdl, msg);
- });
-
- m_client.set_fail_handler(bind(&websocket_client::on_fail, this, _1));
- m_client.clear_access_channels(websocketpp::log::alevel::all);
- }
- void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
- const std::string& payload = msg->get_payload();
- switch (msg->get_opcode()) {
- case websocketpp::frame::opcode::text:
- std::cout << "on_message=" << payload << std::endl;
- }
- }
- // This method will block until the connection is complete
-
- void run(const std::string& uri, const std::string& wav_path) {
- // Create a new connection to the given URI
- websocketpp::lib::error_code ec;
- typename websocketpp::client<T>::connection_ptr con =
- m_client.get_connection(uri, ec);
- if (ec) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
- "Get Connection Error: " + ec.message());
- return;
- }
- this->wav_path = std::move(wav_path);
- // Grab a handle for this connection so we can talk to it in a thread
- // safe manor after the event loop starts.
- m_hdl = con->get_handle();
-
- // Queue the connection. No DNS queries or network connections will be
- // made until the io_service event loop is run.
- m_client.connect(con);
-
- // Create a thread to run the ASIO io_service event loop
- websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
- &m_client);
-
- send_wav_data();
- asio_thread.join();
- }
-
- // The open handler will signal that we are ready to start sending data
- void on_open(websocketpp::connection_hdl) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
- "Connection opened, starting data!");
-
- scoped_lock guard(m_lock);
- m_open = true;
- }
-
- // The close handler will signal that we should stop sending data
- void on_close(websocketpp::connection_hdl) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
- "Connection closed, stopping data!");
-
- scoped_lock guard(m_lock);
- m_done = true;
- }
-
- // The fail handler will signal that we should stop sending data
- void on_fail(websocketpp::connection_hdl) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
- "Connection failed, stopping data!");
-
- scoped_lock guard(m_lock);
- m_done = true;
- }
- // send wav to server
- void send_wav_data() {
- uint64_t count = 0;
- std::stringstream val;
-
- funasr::Audio audio(1);
- int32_t sampling_rate = 16000;
-
- if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) {
- std::cout << "error in load wav" << std::endl;
- return;
- }
-
- float* buff;
- int len;
- int flag = 0;
- bool wait = false;
- while (1) {
- {
- scoped_lock guard(m_lock);
- // If the connection has been closed, stop generating data
- if (m_done) {
- break;
- }
-
- // If the connection hasn't been opened yet wait a bit and retry
- if (!m_open) {
- wait = true;
- } else {
- break;
- }
- }
-
- if (wait) {
- std::cout << "wait.." << m_open << std::endl;
- wait_a_bit();
-
- continue;
- }
- }
- websocketpp::lib::error_code ec;
-
- nlohmann::json jsonbegin;
- nlohmann::json chunk_size = nlohmann::json::array();
- chunk_size.push_back(5);
- chunk_size.push_back(0);
- chunk_size.push_back(5);
- jsonbegin["chunk_size"] = chunk_size;
- jsonbegin["chunk_interval"] = 10;
- jsonbegin["wav_name"] = "damo";
- jsonbegin["is_speaking"] = true;
- m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
- ec);
-
- // fetch wav data use asr engine api
- while (audio.Fetch(buff, len, flag) > 0) {
- short iArray[len];
-
- // convert float -1,1 to short -32768,32767
- for (size_t i = 0; i < len; ++i) {
- iArray[i] = (short)(buff[i] * 32767);
- }
- // send data to server
- m_client.send(m_hdl, iArray, len * sizeof(short),
- websocketpp::frame::opcode::binary, ec);
- std::cout << "sended data len=" << len * sizeof(short) << std::endl;
- // The most likely error that we will get is that the connection is
- // not in the right state. Usually this means we tried to send a
- // message to a connection that was closed or in the process of
- // closing. While many errors here can be easily recovered from,
- // in this simple example, we'll stop the data loop.
- if (ec) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
- "Send Error: " + ec.message());
- break;
- }
-
- wait_a_bit();
- }
- nlohmann::json jsonresult;
- jsonresult["is_speaking"] = false;
- m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
- ec);
- wait_a_bit();
- }
- websocketpp::client<T> m_client;
-
- private:
- websocketpp::connection_hdl m_hdl;
- websocketpp::lib::mutex m_lock;
- std::string wav_path;
- bool m_open;
- bool m_done;
-};
-
-int main(int argc, char* argv[]) {
- if (argc < 6) {
- printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
- exit(-1);
- }
- std::string server_ip = argv[1];
- std::string port = argv[2];
- std::string wav_path = argv[3];
- int threads_num = atoi(argv[4]);
- int is_ssl = atoi(argv[5]);
- std::vector<websocketpp::lib::thread> client_threads;
- std::string uri = "";
- if (is_ssl == 1) {
- uri = "wss://" + server_ip + ":" + port;
- } else {
- uri = "ws://" + server_ip + ":" + port;
- }
-
- for (size_t i = 0; i < threads_num; i++) {
- client_threads.emplace_back([uri, wav_path, is_ssl]() {
- if (is_ssl == 1) {
- websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
-
- c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
-
- c.run(uri, wav_path);
- } else {
- websocket_client<websocketpp::config::asio_client> c(is_ssl);
-
- c.run(uri, wav_path);
- }
- });
- }
-
- for (auto& t : client_threads) {
- t.join();
- }
-}
\ No newline at end of file
diff --git a/funasr/runtime/websocket/websocketmain.cpp b/funasr/runtime/websocket/websocketmain.cpp
deleted file mode 100644
index 306c3f0..0000000
--- a/funasr/runtime/websocket/websocketmain.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
- * Reserved. MIT License (https://opensource.org/licenses/MIT)
- */
-/* 2022-2023 by zhaomingwork */
-
-// io server
-// Usage:websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
-// [--io_thread_num <int>] [--port <int>] [--listen_ip
-// <string>] [--punc-quant <string>] [--punc-dir <string>]
-// [--vad-quant <string>] [--vad-dir <string>] [--quantize
-// <string>] --model-dir <string> [--] [--version] [-h]
-#include "websocketsrv.h"
-
-using namespace std;
-void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
- std::map<std::string, std::string>& model_path) {
- if (value_arg.isSet()) {
- model_path.insert({key, value_arg.getValue()});
- LOG(INFO) << key << " : " << value_arg.getValue();
- }
-}
-int main(int argc, char* argv[]) {
- try {
- google::InitGoogleLogging(argv[0]);
- FLAGS_logtostderr = true;
-
- TCLAP::CmdLine cmd("websocketmain", ' ', "1.0");
- TCLAP::ValueArg<std::string> model_dir(
- "", MODEL_DIR,
- "the asr model path, which contains model.onnx, config.yaml, am.mvn",
- true, "", "string");
- TCLAP::ValueArg<std::string> quantize(
- "", QUANTIZE,
- "false (Default), load the model of model.onnx in model_dir. If set "
- "true, load the model of model_quant.onnx in model_dir",
- false, "false", "string");
- TCLAP::ValueArg<std::string> vad_dir(
- "", VAD_DIR,
- "the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
- false, "", "string");
- TCLAP::ValueArg<std::string> vad_quant(
- "", VAD_QUANT,
- "false (Default), load the model of model.onnx in vad_dir. If set "
- "true, load the model of model_quant.onnx in vad_dir",
- false, "false", "string");
- TCLAP::ValueArg<std::string> punc_dir(
- "", PUNC_DIR,
- "the punc model path, which contains model.onnx, punc.yaml", false, "",
- "string");
- TCLAP::ValueArg<std::string> punc_quant(
- "", PUNC_QUANT,
- "false (Default), load the model of model.onnx in punc_dir. If set "
- "true, load the model of model_quant.onnx in punc_dir",
- false, "false", "string");
-
- TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
- "0.0.0.0", "string");
- TCLAP::ValueArg<int> port("", "port", "port", false, 8889, "int");
- TCLAP::ValueArg<int> io_thread_num("", "io_thread_num", "io_thread_num",
- false, 8, "int");
- TCLAP::ValueArg<int> decoder_thread_num(
- "", "decoder_thread_num", "decoder_thread_num", false, 8, "int");
- TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
- "model_thread_num", false, 1, "int");
-
- TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
- "string");
- TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
- "string");
-
- cmd.add(certfile);
- cmd.add(keyfile);
-
- cmd.add(model_dir);
- cmd.add(quantize);
- cmd.add(vad_dir);
- cmd.add(vad_quant);
- cmd.add(punc_dir);
- cmd.add(punc_quant);
-
- cmd.add(listen_ip);
- cmd.add(port);
- cmd.add(io_thread_num);
- cmd.add(decoder_thread_num);
- cmd.add(model_thread_num);
- cmd.parse(argc, argv);
-
- std::map<std::string, std::string> model_path;
- GetValue(model_dir, MODEL_DIR, model_path);
- GetValue(quantize, QUANTIZE, model_path);
- GetValue(vad_dir, VAD_DIR, model_path);
- GetValue(vad_quant, VAD_QUANT, model_path);
- GetValue(punc_dir, PUNC_DIR, model_path);
- GetValue(punc_quant, PUNC_QUANT, model_path);
-
- std::string s_listen_ip = listen_ip.getValue();
- int s_port = port.getValue();
- int s_io_thread_num = io_thread_num.getValue();
- int s_decoder_thread_num = decoder_thread_num.getValue();
-
- int s_model_thread_num = model_thread_num.getValue();
-
- asio::io_context io_decoder; // context for decoding
-
- std::vector<std::thread> decoder_threads;
-
- std::string s_certfile = certfile.getValue();
- std::string s_keyfile = keyfile.getValue();
-
- bool is_ssl = false;
- if (!s_certfile.empty()) {
- is_ssl = true;
- }
-
- auto conn_guard = asio::make_work_guard(
- io_decoder); // make sure threads can wait in the queue
-
- // create threads pool
- for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
- decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
- }
-
- server server_; // server for websocket
- wss_server wss_server_;
- if (is_ssl) {
- wss_server_.init_asio(); // init asio
- wss_server_.set_reuse_addr(
- true); // reuse address as we create multiple threads
-
- // list on port for accept
- wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
- WebSocketServer websocket_srv(
- io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
- s_keyfile); // websocket server for asr engine
- websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
-
- } else {
- server_.init_asio(); // init asio
- server_.set_reuse_addr(
- true); // reuse address as we create multiple threads
-
- // list on port for accept
- server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
- WebSocketServer websocket_srv(
- io_decoder, is_ssl, &server_, nullptr, s_certfile,
- s_keyfile); // websocket server for asr engine
- websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
- }
-
- std::cout << "asr model init finished. listen on port:" << s_port
- << std::endl;
-
- // Start the ASIO network io_service run loop
- if (s_io_thread_num == 1) {
- if (is_ssl) {
- wss_server_.run();
- } else {
- server_.run();
- }
- } else {
- typedef websocketpp::lib::shared_ptr<websocketpp::lib::thread> thread_ptr;
- std::vector<thread_ptr> ts;
- // create threads for io network
- for (size_t i = 0; i < s_io_thread_num; i++) {
- if (is_ssl) {
- ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
- &wss_server::run, &wss_server_));
- } else {
- ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
- &server::run, &server_));
- }
- }
- // wait for theads
- for (size_t i = 0; i < s_io_thread_num; i++) {
- ts[i]->join();
- }
- }
-
- // wait for theads
- for (auto& t : decoder_threads) {
- t.join();
- }
-
- } catch (std::exception const& e) {
- std::cerr << "Error: " << e.what() << std::endl;
- }
-
- return 0;
-}
\ No newline at end of file
diff --git a/funasr/samplers/build_batch_sampler.py b/funasr/samplers/build_batch_sampler.py
index 074b446..9266fea 100644
--- a/funasr/samplers/build_batch_sampler.py
+++ b/funasr/samplers/build_batch_sampler.py
@@ -4,8 +4,6 @@
from typing import Tuple
from typing import Union
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.samplers.abs_sampler import AbsSampler
from funasr.samplers.folded_batch_sampler import FoldedBatchSampler
@@ -104,7 +102,6 @@
padding: Whether sequences are input as a padded tensor or not.
used for "numel" mode
"""
- assert check_argument_types()
if len(shape_files) == 0:
raise ValueError("No shape file are given")
@@ -164,5 +161,4 @@
else:
raise ValueError(f"Not supported: {type}")
- assert check_return_type(retval)
return retval
diff --git a/funasr/samplers/folded_batch_sampler.py b/funasr/samplers/folded_batch_sampler.py
index 48e9604..f48d744 100644
--- a/funasr/samplers/folded_batch_sampler.py
+++ b/funasr/samplers/folded_batch_sampler.py
@@ -4,7 +4,6 @@
from typing import Tuple
from typing import Union
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
@@ -23,7 +22,6 @@
drop_last: bool = False,
utt2category_file: str = None,
):
- assert check_argument_types()
assert batch_size > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(
diff --git a/funasr/samplers/length_batch_sampler.py b/funasr/samplers/length_batch_sampler.py
index 8ee8bdc..28404e3 100644
--- a/funasr/samplers/length_batch_sampler.py
+++ b/funasr/samplers/length_batch_sampler.py
@@ -4,7 +4,6 @@
from typing import Tuple
from typing import Union
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -21,7 +20,6 @@
drop_last: bool = False,
padding: bool = True,
):
- assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(
diff --git a/funasr/samplers/num_elements_batch_sampler.py b/funasr/samplers/num_elements_batch_sampler.py
index 0ffad92..ebed0e4 100644
--- a/funasr/samplers/num_elements_batch_sampler.py
+++ b/funasr/samplers/num_elements_batch_sampler.py
@@ -4,7 +4,6 @@
from typing import Union
import numpy as np
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -21,7 +20,6 @@
drop_last: bool = False,
padding: bool = True,
):
- assert check_argument_types()
assert batch_bins > 0
if sort_batch != "ascending" and sort_batch != "descending":
raise ValueError(
diff --git a/funasr/samplers/sorted_batch_sampler.py b/funasr/samplers/sorted_batch_sampler.py
index d6c3b41..b31e93e 100644
--- a/funasr/samplers/sorted_batch_sampler.py
+++ b/funasr/samplers/sorted_batch_sampler.py
@@ -2,7 +2,6 @@
from typing import Iterator
from typing import Tuple
-from typeguard import check_argument_types
from funasr.fileio.read_text import load_num_sequence_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -26,7 +25,6 @@
sort_batch: str = "ascending",
drop_last: bool = False,
):
- assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.shape_file = shape_file
diff --git a/funasr/samplers/unsorted_batch_sampler.py b/funasr/samplers/unsorted_batch_sampler.py
index 349e526..e5ed05b 100644
--- a/funasr/samplers/unsorted_batch_sampler.py
+++ b/funasr/samplers/unsorted_batch_sampler.py
@@ -2,7 +2,6 @@
from typing import Iterator
from typing import Tuple
-from typeguard import check_argument_types
from funasr.fileio.read_text import read_2column_text
from funasr.samplers.abs_sampler import AbsSampler
@@ -28,7 +27,6 @@
drop_last: bool = False,
utt2category_file: str = None,
):
- assert check_argument_types()
assert batch_size > 0
self.batch_size = batch_size
self.key_file = key_file
diff --git a/funasr/schedulers/noam_lr.py b/funasr/schedulers/noam_lr.py
index 80df019..e08fb63 100644
--- a/funasr/schedulers/noam_lr.py
+++ b/funasr/schedulers/noam_lr.py
@@ -4,7 +4,6 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler
-from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@@ -31,7 +30,6 @@
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
- assert check_argument_types()
self.model_size = model_size
self.warmup_steps = warmup_steps
diff --git a/funasr/schedulers/tri_stage_scheduler.py b/funasr/schedulers/tri_stage_scheduler.py
index 8dc71b4..c442260 100644
--- a/funasr/schedulers/tri_stage_scheduler.py
+++ b/funasr/schedulers/tri_stage_scheduler.py
@@ -8,7 +8,6 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler
-from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@@ -22,7 +21,6 @@
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.01,
):
- assert check_argument_types()
self.optimizer = optimizer
self.last_epoch = last_epoch
self.phase_ratio = phase_ratio
diff --git a/funasr/schedulers/warmup_lr.py b/funasr/schedulers/warmup_lr.py
index dbf3aca..95ebaca 100644
--- a/funasr/schedulers/warmup_lr.py
+++ b/funasr/schedulers/warmup_lr.py
@@ -3,7 +3,6 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler
-from typeguard import check_argument_types
from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
@@ -30,7 +29,6 @@
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
- assert check_argument_types()
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 0fb77a9..91d33c5 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -32,8 +32,6 @@
import yaml
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr import __version__
from funasr.datasets.dataset import AbsDataset
@@ -269,7 +267,6 @@
@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:
- assert check_argument_types()
class ArgumentDefaultsRawTextHelpFormatter(
argparse.RawTextHelpFormatter,
@@ -959,7 +956,6 @@
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)
- assert check_return_type(parser)
return parser
@classmethod
@@ -1007,7 +1003,6 @@
return _cls
# This method is used only for --print_config
- assert check_argument_types()
parser = cls.get_parser()
args, _ = parser.parse_known_args()
config = vars(args)
@@ -1047,7 +1042,6 @@
@classmethod
def check_required_command_args(cls, args: argparse.Namespace):
- assert check_argument_types()
if hasattr(args, "required"):
for k in vars(args):
if "-" in k:
@@ -1077,7 +1071,6 @@
inference: bool = False,
) -> None:
"""Check if the dataset satisfy the requirement of current Task"""
- assert check_argument_types()
mes = (
f"If you intend to use an additional input, modify "
f'"{cls.__name__}.required_data_names()" or '
@@ -1104,14 +1097,12 @@
@classmethod
def print_config(cls, file=sys.stdout) -> None:
- assert check_argument_types()
# Shows the config: e.g. python train.py asr --print_config
config = cls.get_default_config()
file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
- assert check_argument_types()
print(get_commandline_args(), file=sys.stderr)
if args is None:
parser = cls.get_parser()
@@ -1148,7 +1139,6 @@
@classmethod
def main_worker(cls, args: argparse.Namespace):
- assert check_argument_types()
# 0. Init distributed process
distributed_option = build_dataclass(DistributedOption, args)
@@ -1556,7 +1546,6 @@
- 4 epoch with "--num_iters_per_epoch" == 4
"""
- assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
# Overwrite iter_options if any kwargs is given
@@ -1589,7 +1578,6 @@
def build_sequence_iter_factory(
cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
) -> AbsIterFactory:
- assert check_argument_types()
if hasattr(args, "frontend_conf"):
if args.frontend_conf is not None and "fs" in args.frontend_conf:
@@ -1683,7 +1671,6 @@
iter_options: IteratorOptions,
mode: str,
) -> AbsIterFactory:
- assert check_argument_types()
dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
@@ -1788,7 +1775,6 @@
def build_multiple_iter_factory(
cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
):
- assert check_argument_types()
iter_options = cls.build_iter_options(args, distributed_option, mode)
assert len(iter_options.data_path_and_name_and_type) > 0, len(
iter_options.data_path_and_name_and_type
@@ -1885,7 +1871,6 @@
inference: bool = False,
) -> DataLoader:
"""Build DataLoader using iterable dataset"""
- assert check_argument_types()
# For backward compatibility for pytorch DataLoader
if collate_fn is not None:
kwargs = dict(collate_fn=collate_fn)
@@ -1935,7 +1920,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 92333ab..3ab68df 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -13,8 +13,6 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -38,6 +36,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +44,7 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +54,7 @@
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +135,7 @@
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
),
type_check=FunASRModel,
default="asr",
@@ -175,6 +177,27 @@
type_check=AbsEncoder,
default="rnn",
)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
@@ -197,6 +220,7 @@
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -330,6 +354,12 @@
help="whether to split text using <space>",
)
group.add_argument(
+ "--max_spk_num",
+ type=int_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
"--seg_dict_file",
type=str,
default=None,
@@ -459,7 +489,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -467,7 +496,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -497,7 +525,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -516,12 +543,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -626,7 +651,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
@@ -669,7 +693,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -806,7 +829,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -828,7 +850,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@@ -943,7 +964,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -1053,7 +1073,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -1075,7 +1094,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@@ -1178,7 +1196,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -1276,7 +1293,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
@@ -1301,7 +1317,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -1356,7 +1371,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
@classmethod
@@ -1393,7 +1407,6 @@
Return:
model: ASR Transducer model.
"""
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@@ -1492,6 +1505,123 @@
"Initialization part will be reworked in a short future.",
)
- #assert check_return_type(model)
+
+ return model
+
+
+class ASRTaskSAASR(ASRTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 5. Encoder
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+ # 7. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # 8. CTC
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ # import ipdb;ipdb.set_trace()
+ # 9. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
return model
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
index 9a64e1f..b11d7de 100644
--- a/funasr/tasks/data2vec.py
+++ b/funasr/tasks/data2vec.py
@@ -8,8 +8,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -256,14 +254,12 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
return CommonCollateFn(clipping=True)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -289,7 +285,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -305,12 +300,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
# 1. frontend
if args.input_size is None:
@@ -372,5 +365,4 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 2625fec..a486a46 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -21,8 +21,6 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -344,7 +342,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -352,7 +349,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -382,7 +378,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -401,12 +396,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -505,7 +498,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -528,7 +520,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@@ -764,7 +755,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -772,7 +762,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
# if args.use_preprocessor:
# retval = CommonPreprocessor(
# train=train,
@@ -802,7 +791,6 @@
# )
# else:
# retval = None
- # assert check_return_type(retval)
return None
@classmethod
@@ -821,12 +809,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
# 1. frontend
if args.input_size is None or args.frontend == "wav_frontend_mel23":
@@ -865,7 +851,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -888,7 +873,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 44fdf8e..c0259a8 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -9,8 +9,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -52,7 +50,6 @@
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
# NOTE(kamo): Use '_' instead of '-' to avoid confusion
- assert check_argument_types()
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
@@ -130,7 +127,6 @@
for class_choices in cls.class_choices_list:
class_choices.add_arguments(group)
- assert check_return_type(parser)
return parser
@classmethod
@@ -140,14 +136,12 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
return CommonCollateFn(int_pad_value=0)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -160,7 +154,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -179,7 +172,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace) -> LanguageModel:
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index a63bbe4..de5c897 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -9,8 +9,6 @@
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
@@ -47,7 +45,6 @@
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
# NOTE(kamo): Use '_' instead of '-' to avoid confusion
- assert check_argument_types()
group = parser.add_argument_group(description="Task related")
# NOTE(kamo): add_arguments(..., required=True) can't be used
@@ -126,7 +123,6 @@
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
- assert check_return_type(parser)
return parser
@classmethod
@@ -136,14 +132,12 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
return CommonCollateFn(int_pad_value=0)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
token_types = [args.token_type, args.token_type]
token_lists = [args.token_list, args.punc_list]
bpemodels = [args.bpemodel, args.bpemodel]
@@ -161,7 +155,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -182,7 +175,6 @@
@classmethod
def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -223,5 +215,4 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 4769758..e7ee5a3 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -13,8 +13,6 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -39,7 +37,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +118,7 @@
model_choices = ClassChoices(
"model",
classes=dict(
- asr=ESPnetASRModel,
+ asr=SAASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
@@ -445,7 +443,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -453,7 +450,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -483,7 +479,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -502,12 +497,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -619,5 +612,4 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
- return model
+ return model
\ No newline at end of file
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index e4815da..e698522 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -17,8 +17,6 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
@@ -273,7 +271,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -281,7 +278,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -309,7 +305,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -330,12 +325,10 @@
retval = ()
if inference:
retval = ("ref_speech",)
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetSVModel:
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -449,7 +442,6 @@
if args.init is not None:
initialize(model, args.init)
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -472,7 +464,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index ec95596..822be22 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -13,8 +13,6 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.layers.abs_normalize import AbsNormalize
@@ -192,7 +190,6 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -200,7 +197,6 @@
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
# if args.use_preprocessor:
# retval = CommonPreprocessor(
# train=train,
@@ -223,7 +219,6 @@
# else:
# retval = None
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -242,12 +237,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(**args.encoder_conf)
@@ -297,7 +290,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
diff --git a/funasr/text/build_tokenizer.py b/funasr/text/build_tokenizer.py
index 8e29d3e..c60a335 100644
--- a/funasr/text/build_tokenizer.py
+++ b/funasr/text/build_tokenizer.py
@@ -2,7 +2,6 @@
from typing import Iterable
from typing import Union
-from typeguard import check_argument_types
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.text.char_tokenizer import CharTokenizer
@@ -21,7 +20,6 @@
g2p_type: str = None,
) -> AbsTokenizer:
"""A helper function to instantiate Tokenizer"""
- assert check_argument_types()
if token_type == "bpe":
if bpemodel is None:
raise ValueError('bpemodel is required if token_type = "bpe"')
diff --git a/funasr/text/char_tokenizer.py b/funasr/text/char_tokenizer.py
index 00ae427..8d1daf4 100644
--- a/funasr/text/char_tokenizer.py
+++ b/funasr/text/char_tokenizer.py
@@ -4,7 +4,6 @@
from typing import Union
import warnings
-from typeguard import check_argument_types
from funasr.text.abs_tokenizer import AbsTokenizer
@@ -16,7 +15,6 @@
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
- assert check_argument_types()
self.space_symbol = space_symbol
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()
diff --git a/funasr/text/cleaner.py b/funasr/text/cleaner.py
index be26940..6322672 100644
--- a/funasr/text/cleaner.py
+++ b/funasr/text/cleaner.py
@@ -2,7 +2,6 @@
from jaconv import jaconv
import tacotron_cleaner.cleaners
-from typeguard import check_argument_types
try:
from vietnamese_cleaner import vietnamese_cleaners
@@ -21,7 +20,6 @@
"""
def __init__(self, cleaner_types: Collection[str] = None):
- assert check_argument_types()
if cleaner_types is None:
self.cleaner_types = []
diff --git a/funasr/text/phoneme_tokenizer.py b/funasr/text/phoneme_tokenizer.py
index d424b40..ad3d81c 100644
--- a/funasr/text/phoneme_tokenizer.py
+++ b/funasr/text/phoneme_tokenizer.py
@@ -9,7 +9,6 @@
# import g2p_en
import jamo
-from typeguard import check_argument_types
from funasr.text.abs_tokenizer import AbsTokenizer
@@ -365,7 +364,6 @@
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
- assert check_argument_types()
if g2p_type is None:
self.g2p = split_by_space
elif g2p_type == "g2p_en":
diff --git a/funasr/text/sentencepiece_tokenizer.py b/funasr/text/sentencepiece_tokenizer.py
index e4cc152..e393cee 100644
--- a/funasr/text/sentencepiece_tokenizer.py
+++ b/funasr/text/sentencepiece_tokenizer.py
@@ -4,14 +4,12 @@
from typing import Union
import sentencepiece as spm
-from typeguard import check_argument_types
from funasr.text.abs_tokenizer import AbsTokenizer
class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
- assert check_argument_types()
self.model = str(model)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
diff --git a/funasr/text/token_id_converter.py b/funasr/text/token_id_converter.py
index c9a6b28..1888d75 100644
--- a/funasr/text/token_id_converter.py
+++ b/funasr/text/token_id_converter.py
@@ -5,7 +5,6 @@
from typing import Union
import numpy as np
-from typeguard import check_argument_types
class TokenIDConverter:
@@ -14,7 +13,6 @@
token_list: Union[Path, str, Iterable[str]],
unk_symbol: str = "<unk>",
):
- assert check_argument_types()
if isinstance(token_list, (Path, str)):
token_list = Path(token_list)
diff --git a/funasr/text/word_tokenizer.py b/funasr/text/word_tokenizer.py
index 842734e..f4d33d5 100644
--- a/funasr/text/word_tokenizer.py
+++ b/funasr/text/word_tokenizer.py
@@ -4,7 +4,6 @@
from typing import Union
import warnings
-from typeguard import check_argument_types
from funasr.text.abs_tokenizer import AbsTokenizer
@@ -16,7 +15,6 @@
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
):
- assert check_argument_types()
self.delimiter = delimiter
if not remove_non_linguistic_symbols and non_linguistic_symbols is not None:
diff --git a/funasr/torch_utils/forward_adaptor.py b/funasr/torch_utils/forward_adaptor.py
index 114af78..eb6da2b 100644
--- a/funasr/torch_utils/forward_adaptor.py
+++ b/funasr/torch_utils/forward_adaptor.py
@@ -1,5 +1,4 @@
import torch
-from typeguard import check_argument_types
class ForwardAdaptor(torch.nn.Module):
@@ -21,7 +20,6 @@
"""
def __init__(self, module: torch.nn.Module, name: str):
- assert check_argument_types()
super().__init__()
self.module = module
self.name = name
diff --git a/funasr/torch_utils/initialize.py b/funasr/torch_utils/initialize.py
index 2c0e7a4..e4ec534 100644
--- a/funasr/torch_utils/initialize.py
+++ b/funasr/torch_utils/initialize.py
@@ -4,7 +4,6 @@
import math
import torch
-from typeguard import check_argument_types
def initialize(model: torch.nn.Module, init: str):
@@ -19,7 +18,6 @@
model: Target.
init: Method of initialization.
"""
- assert check_argument_types()
if init == "chainer":
# 1. lecun_normal_init_parameters
diff --git a/funasr/train/abs_model.py b/funasr/train/abs_model.py
index 8d684be..9687376 100644
--- a/funasr/train/abs_model.py
+++ b/funasr/train/abs_model.py
@@ -8,7 +8,6 @@
import torch
import torch.nn.functional as F
-from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
@@ -34,7 +33,6 @@
class LanguageModel(FunASRModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
- assert check_argument_types()
super().__init__()
self.lm = lm
self.sos = 1
@@ -154,7 +152,6 @@
class PunctuationModel(FunASRModel):
def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
- assert check_argument_types()
super().__init__()
self.punc_model = punc_model
self.punc_weight = torch.Tensor(punc_weight)
diff --git a/funasr/train/class_choices.py b/funasr/train/class_choices.py
index 658d291..1ffb97a 100644
--- a/funasr/train/class_choices.py
+++ b/funasr/train/class_choices.py
@@ -2,8 +2,6 @@
from typing import Optional
from typing import Tuple
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import str_or_none
@@ -40,7 +38,6 @@
default: str = None,
optional: bool = False,
):
- assert check_argument_types()
self.name = name
self.base_type = type_check
self.classes = {k.lower(): v for k, v in classes.items()}
@@ -64,12 +61,10 @@
return retval
def get_class(self, name: Optional[str]) -> Optional[type]:
- assert check_argument_types()
if name is None or (self.optional and name.lower() == ("none", "null", "nil")):
retval = None
elif name.lower() in self.classes:
class_obj = self.classes[name]
- assert check_return_type(class_obj)
retval = class_obj
else:
raise ValueError(
diff --git a/funasr/train/reporter.py b/funasr/train/reporter.py
index 2921fef..cfe31f5 100644
--- a/funasr/train/reporter.py
+++ b/funasr/train/reporter.py
@@ -18,8 +18,6 @@
import humanfriendly
import numpy as np
import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
Num = Union[float, int, complex, torch.Tensor, np.ndarray]
@@ -27,7 +25,6 @@
def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue":
- assert check_argument_types()
if isinstance(v, (torch.Tensor, np.ndarray)):
if np.prod(v.shape) != 1:
raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}")
@@ -42,12 +39,10 @@
retval = WeightedAverage(v, weight)
else:
retval = Average(v)
- assert check_return_type(retval)
return retval
def aggregate(values: Sequence["ReportedValue"]) -> Num:
- assert check_argument_types()
for v in values:
if not isinstance(v, type(values[0])):
@@ -86,7 +81,6 @@
else:
raise NotImplementedError(f"type={type(values[0])}")
- assert check_return_type(retval)
return retval
@@ -122,7 +116,6 @@
"""
def __init__(self, key: str, epoch: int, total_count: int):
- assert check_argument_types()
self.key = key
self.epoch = epoch
self.start_time = time.perf_counter()
@@ -160,7 +153,6 @@
stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]],
weight: Num = None,
) -> None:
- assert check_argument_types()
if self._finished:
raise RuntimeError("Already finished")
if len(self._seen_keys_in_the_step) == 0:
@@ -293,7 +285,6 @@
"""
def __init__(self, epoch: int = 0):
- assert check_argument_types()
if epoch < 0:
raise ValueError(f"epoch must be 0 or more: {epoch}")
self.epoch = epoch
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index f066909..a25f39a 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -26,7 +26,6 @@
import torch
import torch.nn
import torch.optim
-from typeguard import check_argument_types
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models
@@ -127,7 +126,6 @@
@classmethod
def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
"""Build options consumed by train(), eval()"""
- assert check_argument_types()
return build_dataclass(TrainerOptions, args)
@classmethod
@@ -188,7 +186,6 @@
distributed_option: DistributedOption,
) -> None:
"""Perform training. This method performs the main process of training."""
- assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
assert is_dataclass(trainer_options), type(trainer_options)
assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
@@ -551,7 +548,6 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> Tuple[bool, bool]:
- assert check_argument_types()
grad_noise = options.grad_noise
accum_grad = options.accum_grad
@@ -845,7 +841,6 @@
options: TrainerOptions,
distributed_option: DistributedOption,
) -> None:
- assert check_argument_types()
ngpu = options.ngpu
no_forward_run = options.no_forward_run
distributed = distributed_option.distributed
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 4067b04..5aa40ec 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
+import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
if support_audio_type == "pcm":
fs = None
else:
- audio, fs = torchaudio.load(fname)
+ try:
+ audio, fs = torchaudio.load(fname)
+ except:
+ audio, fs = soundfile.read(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/build_dataclass.py b/funasr/utils/build_dataclass.py
index 6675c99..0d59ad9 100644
--- a/funasr/utils/build_dataclass.py
+++ b/funasr/utils/build_dataclass.py
@@ -1,7 +1,6 @@
import argparse
import dataclasses
-from typeguard import check_type
def build_dataclass(dataclass, args: argparse.Namespace):
@@ -12,6 +11,5 @@
raise ValueError(
f"args doesn't have {field.name}. You need to set it to ArgumentsParser"
)
- check_type(field.name, getattr(args, field.name), field.type)
kwargs[field.name] = getattr(args, field.name)
return dataclass(**kwargs)
diff --git a/funasr/utils/griffin_lim.py b/funasr/utils/griffin_lim.py
index c1536d5..9e98ab8 100644
--- a/funasr/utils/griffin_lim.py
+++ b/funasr/utils/griffin_lim.py
@@ -9,7 +9,6 @@
from distutils.version import LooseVersion
from functools import partial
-from typeguard import check_argument_types
from typing import Optional
import librosa
@@ -138,7 +137,6 @@
griffin_lim_iters: The number of iterations.
"""
- assert check_argument_types()
self.fs = fs
self.logmel2linear = (
partial(
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 7602740..0e773bb 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
+import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
return n_frames, feature_dim
diff --git a/funasr/utils/runtime_sdk_download_tool.py b/funasr/utils/runtime_sdk_download_tool.py
new file mode 100644
index 0000000..f8d4bc9
--- /dev/null
+++ b/funasr/utils/runtime_sdk_download_tool.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+import os
+import argparse
+from funasr.utils.types import str2bool
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model-name', type=str, required=True)
+parser.add_argument('--export-dir', type=str, required=True)
+parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+args = parser.parse_args()
+
+model_dir = args.model_name
+if not Path(args.model_name).exists():
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
+ (model_dir)
+
+model_file = os.path.join(model_dir, 'model.onnx')
+if args.quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+if not os.path.exists(model_file):
+ print(".onnx is not exist, begin to export onnx")
+ from funasr.export.export_model import ModelExport
+ export_model = ModelExport(
+ cache_dir=args.export_dir,
+ onnx=True,
+ device="cpu",
+ quant=args.quantize,
+ )
+ export_model.export(model_dir)
\ No newline at end of file
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index ebb80d2..bd067c2 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,13 @@
waveform = torch.from_numpy(waveform.reshape(1, -1))
else:
# load pcm from wav, and resample
- waveform, audio_sr = torchaudio.load(wav_file)
+ try:
+ waveform, audio_sr = torchaudio.load(wav_file)
+ except:
+ waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+ if waveform.ndim == 2:
+ waveform = waveform[:, 0]
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +188,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
diff --git a/funasr/version.txt b/funasr/version.txt
index 659914a..2228cad 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.5.8
+0.6.7
diff --git a/setup.py b/setup.py
index 0e787ab..6bb4bcd 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,6 @@
import os
-from distutils.version import LooseVersion
from setuptools import find_packages
from setuptools import setup
@@ -12,13 +11,10 @@
requirements = {
"install": [
"setuptools>=38.5.1",
- # "configargparse>=1.2.1",
- "typeguard==2.13.3",
"humanfriendly",
"scipy>=1.4.1",
- # "filelock",
"librosa",
- "jamo==0.4.1", # For kss
+ "jamo", # For kss
"PyYAML>=5.1.2",
"soundfile>=0.10.2",
"h5py>=2.10.0",
@@ -27,58 +23,32 @@
"nltk>=3.4.5",
# ASR
"sentencepiece",
- # "ctc-segmentation<1.8,>=1.6.6",
# TTS
- # "pyworld>=0.2.10",
- "pypinyin<=0.44.0",
+ "pypinyin>=0.44.0",
"espnet_tts_frontend",
# ENH
- # "ci_sdr",
"pytorch_wpe",
- "editdistance==0.5.2",
- "tensorboard==1.15",
+ "editdistance>=0.5.2",
+ "tensorboard",
"g2p",
# PAI
"oss2",
- # "kaldi-native-fbank",
- # timestamp
"edit-distance",
- # textgrid
"textgrid",
- "protobuf==3.20.0",
+ "protobuf",
],
# train: The modules invoked when training only.
"train": [
- # "pillow>=6.1.0",
- "editdistance==0.5.2",
+ "editdistance",
"wandb",
- ],
- # recipe: The modules actually are not invoked in the main module of funasr,
- # but are invoked for the python scripts in each recipe
- "recipe": [
- "espnet_model_zoo",
- # "gdown",
- # "resampy",
- # "pysptk>=0.1.17",
- # "morfessor", # for zeroth-korean
- # "youtube_dl", # for laborotv
- # "nnmnkwii",
- # "museval>=0.2.1",
- # "pystoi>=0.2.2",
- # "mir-eval>=0.6",
- # "fastdtw",
- # "nara_wpe>=0.0.5",
- # "sacrebleu>=1.5.1",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
- # NOTE(kamo): The modules in "train" and "recipe" are appended into "all"
"all": [
# NOTE(kamo): Append modules requiring specific pytorch version or torch>1.3.0
"torch_optimizer",
"fairscale",
"transformers",
- # "gtn==0.0.0",
],
"setup": [
"numpy",
@@ -98,17 +68,18 @@
"black",
],
"doc": [
- "Jinja2<3.1",
- "Sphinx==2.1.2",
+ "Jinja2",
+ "Sphinx",
"sphinx-rtd-theme>=0.2.4",
"sphinx-argparse>=0.2.5",
- "commonmark==0.8.1",
+ "commonmark",
"recommonmark>=0.4.0",
"nbsphinx>=0.4.2",
"sphinx-markdown-tables>=0.0.12",
+ "configargparse>=1.2.1"
],
}
-requirements["all"].extend(requirements["train"] + requirements["recipe"])
+requirements["all"].extend(requirements["train"])
requirements["test"].extend(requirements["train"])
install_requires = requirements["install"]
@@ -151,4 +122,4 @@
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Libraries :: Python Modules",
],
-)
+)
\ No newline at end of file
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 9098ea6..2b21acf 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -87,6 +87,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "鍥藉姟闄㈠彂灞曠爺绌朵腑蹇冨競鍦虹粡娴庣爺绌舵墍鍓墍闀块倱閮佹澗璁や负"
def test_paraformer_large_aishell1(self):
inference_pipeline = pipeline(
@@ -95,6 +96,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_aishell2(self):
inference_pipeline = pipeline(
@@ -103,6 +105,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_common(self):
inference_pipeline = pipeline(
@@ -111,6 +114,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_online_common(self):
inference_pipeline = pipeline(
@@ -119,6 +123,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶� 瀹舵潵 浣撻獙杈� 鎽╅櫌鎺� 鍑虹殑 璇煶璇� 鍒ā 鍨�"
def test_paraformer_online_common(self):
inference_pipeline = pipeline(
@@ -127,6 +132,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋 澶у鏉� 浣撻獙杈� 鎽╅櫌鎺� 鍑虹殑 璇煶璇� 鍒ā 鍨�"
def test_paraformer_tiny_commandword(self):
inference_pipeline = pipeline(
diff --git a/tests/test_asr_vad_punc_inference_pipeline.py b/tests/test_asr_vad_punc_inference_pipeline.py
index 628b256..f86f23d 100644
--- a/tests/test_asr_vad_punc_inference_pipeline.py
+++ b/tests/test_asr_vad_punc_inference_pipeline.py
@@ -26,6 +26,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr_vad_punc inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨嬨��"
if __name__ == '__main__':
diff --git a/tests/test_sv_inference_pipeline.py b/tests/test_sv_inference_pipeline.py
index 54ab564..c4e427e 100644
--- a/tests/test_sv_inference_pipeline.py
+++ b/tests/test_sv_inference_pipeline.py
@@ -19,30 +19,20 @@
task=Tasks.speaker_verification,
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
)
- # 鎻愬彇涓嶅悓鍙ュ瓙鐨勮璇濅汉宓屽叆鐮�
- rec_result = inference_sv_pipline(
- audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
- enroll = rec_result["spk_embedding"]
- rec_result = inference_sv_pipline(
- audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')
- same = rec_result["spk_embedding"]
-
- rec_result = inference_sv_pipline(
- audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
- different = rec_result["spk_embedding"]
-
- # 瀵圭浉鍚岀殑璇磋瘽浜鸿绠椾綑寮︾浉浼煎害
- sv_threshold = 0.9465
- same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
- same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
- logger.info("Similarity: {}".format(same_cos))
-
- # 瀵逛笉鍚岀殑璇磋瘽浜鸿绠椾綑寮︾浉浼煎害
- diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
- diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
- logger.info("Similarity: {}".format(diff_cos))
-
+ # the same speaker
+ rec_result = inference_sv_pipline(audio_in=(
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
+ assert abs(rec_result["scores"][0]-0.85) < 0.1 and abs(rec_result["scores"][1]-0.14) < 0.1
+ logger.info(f"Similarity {rec_result['scores']}")
+
+ # different speaker
+ rec_result = inference_sv_pipline(audio_in=(
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
+ 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
+ assert abs(rec_result["scores"][0]-0.0) < 0.1 and abs(rec_result["scores"][1]-1.0) < 0.1
+ logger.info(f"Similarity {rec_result['scores']}")
if __name__ == '__main__':
- unittest.main()
+ unittest.main()
\ No newline at end of file
diff --git a/tests/test_tp_pipeline.py b/tests/test_tp_pipeline.py
new file mode 100644
index 0000000..07084f2
--- /dev/null
+++ b/tests/test_tp_pipeline.py
@@ -0,0 +1,30 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestTimestampPredictionPipelines(unittest.TestCase):
+ def test_funasr_path(self):
+ import funasr
+ import os
+ logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+ def test_inference_pipeline(self):
+ inference_pipeline = pipeline(
+ task=Tasks.speech_timestamp,
+ model='damo/speech_timestamp_prediction-v1-16k-offline',
+ model_revision='v1.1.0')
+
+ rec_result = inference_pipeline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
+ text_in='涓� 涓� 涓� 澶� 骞� 娲� 鍥� 瀹� 涓� 浠� 涔� 璺� 鍒� 瑗� 澶� 骞� 娲� 鏉� 浜� 鍛�',)
+ print(rec_result)
+ logger.info("punctuation inference result: {0}".format(rec_result))
+ assert rec_result=={'text': '<sil> 0.000 0.380;涓� 0.380 0.560;涓� 0.560 0.800;涓� 0.800 0.980;澶� 0.980 1.140;骞� 1.140 1.260;娲� 1.260 1.440;鍥� 1.440 1.680;瀹� 1.680 1.920;<sil> 1.920 2.040;涓� 2.040 2.200;浠� 2.200 2.320;涔� 2.320 2.500;璺� 2.500 2.680;鍒� 2.680 2.860;瑗� 2.860 3.040;澶� 3.040 3.200;骞� 3.200 3.380;娲� 3.380 3.500;鏉� 3.500 3.640;浜� 3.640 3.800;鍛� 3.800 4.150;<sil> 4.150 4.440;', 'timestamp': [[380, 560], [560, 800], [800, 980], [980, 1140], [1140, 1260], [1260, 1440], [1440, 1680], [1680, 1920], [2040, 2200], [2200, 2320], [2320, 2500], [2500, 2680], [2680, 2860], [2860, 3040], [3040, 3200], [3200, 3380], [3380, 3500], [3500, 3640], [3640, 3800], [3800, 4150]]}
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
diff --git a/tests/test_vad_inference_pipeline.py b/tests/test_vad_inference_pipeline.py
index b6601b1..50b8db3 100644
--- a/tests/test_vad_inference_pipeline.py
+++ b/tests/test_vad_inference_pipeline.py
@@ -37,7 +37,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
logger.info("vad inference result: {0}".format(rec_result))
- assert rec_result["text"] == [[80, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
+ assert rec_result["text"] == [[70, 2340], [2620, 6200], [6480, 23670], [23950, 26250], [26780, 28990],
[29950, 31430], [31750, 37600], [38210, 46900], [47310, 49630], [49910, 56460],
[56740, 59540], [59820, 70450]]
--
Gitblit v1.9.1