zhifu gao
2023-04-16 937e507977cc9e49ce323f8b2933087d0fe52698
Merge pull request #363 from alibaba-damo-academy/main

update with main
100个文件已修改
7个文件已删除
19个文件已添加
1 文件已复制
7 文件已重命名
6004 ■■■■ 已修改文件
README.md 37 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/modelscope_models.md 93 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/transformer/utils/compute_wer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py 93 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_streaming.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_rnnt.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_uniasr.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_uniasr_vad.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/diar_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/lm_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_train_vadrealtime.py 44 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer_vadrealtime.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sv_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/vad_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/dataset.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/utils/tokenize.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/preprocessor.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/README.md 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py 52 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/CT_Transformer.py 162 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_asr_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_vad.py 60 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/fsmn_encoder.py 296 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/sanm_encoder.py 104 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx_punc.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx_punc_vadrealtime.py 22 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx_vad.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_torchscripts.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/lm/abs_model.py 129 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/lm/espnet_model.py 131 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/contextual_decoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/sanm_decoder.py 17 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/transformer_decoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_mfcca.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_diar_sond.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_sv.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_tp.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_uni_asr.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_vad.py 35 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/opennmt_encoders/conv_encoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/resnet34_encoder.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/sanm_encoder.py 236 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/predictor/cif.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/target_delay_transformer.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/vad_realtime_transformer.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/streaming_utils/chunk_utilis.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/abs_model.py 31 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/sanm_encoder.py 590 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/punctuation/text_preprocessor.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/grpc/CMakeLists.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/grpc/Readme.md 62 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/grpc/paraformer_server.cc 73 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/grpc/paraformer_server.h 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/CMakeLists.txt 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/Audio.h 17 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/libfunasrapi.h 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/librapidasrapi.h 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/readme.md 99 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/Audio.cpp 262 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/CMakeLists.txt 29 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/FeatureExtract.cpp 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/FeatureExtract.h 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/Vocab.cpp 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/commonfunc.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/libfunasrapi.cpp 184 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/librapidasrapi.cpp 182 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.h 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/resample.cc 305 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/resample.h 137 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/tester/CMakeLists.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/tester/tester.cpp 58 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/tester/tester_rtf.cpp 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/grpc/grpc_main_client.py 62 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/grpc/grpc_server.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/README.md 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/setup.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/README.md 15 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_punc_offline.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_punc_online.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_vad_offline.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_vad_online.py 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py 259 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py 616 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py 184 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py 59 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py 280 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/setup.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/utils/test_rtf_gpu.py 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/diar.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/lm.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/punctuation.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/sv.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/vad.py 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train/abs_model.py 54 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/compute_wer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/version.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_asr_inference_pipeline.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_punctuation_pipeline.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -7,12 +7,11 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new) 
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
| [**Contact**](#contact)
@@ -29,15 +28,37 @@
## Installation
``` sh
pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install --editable ./
Install from pip
```shell
pip install -U funasr
# For the users in China, you could install with the command:
# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install -e ./
# For the users in China, you could install with the command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
If you want to use the pretrained models in ModelScope, you should install the modelscope:
```shell
pip install -U modelscope
# For the users in China, you could install with the command:
# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
For more details, please refer to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
## Usage
For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
[//]: # ()
[//]: # (## Usage)
[//]: # (For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)))
## Contact
docs/modelscope_models.md
@@ -6,29 +6,70 @@
## Model Zoo
Here we provide several pretrained models on different datasets. Details of the models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
| Datasets  | Hours |     Model      | Online/Offline | Language | Framework | Checkpoint |
|:-----:|:-----:|:--------------:|:--------------:| :---: | :---: | --- |
| Alibaba Speech Data | 60000 |   Paraformer   |   Offline   |       CN       | Pytorch |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) |
| Alibaba Speech Data | 50000 |   Paraformer   |   Offline   |       CN       | Tensorflow |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 |   Paraformer   |   Offline   |       CN       | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 |   Paraformer   |   Online    |       CN       | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 |    UniASR     |   Online    |       CN       | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 |    UniASR     |   Offline   |       CN       | Tensorflow |[speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 50000 |    UniASR     |   Online    |     CN&EN      | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 |    UniASR     |   Offline   |     CN&EN      | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 20000 |    UniASR     |   Online    |   CN-Accent    | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 20000 |    UniASR     |    Offline     |   CN-Accent    | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K |     Online     |       CN       | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data |  30000   | Paraformer-8K |    Offline     |       CN       | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data |  30000   | Paraformer-8K |     Online     |       CN       | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
| Alibaba Speech Data |  30000   | Paraformer-8K |    Offline     |       CN       | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
| Alibaba Speech Data |  30000   |   UniASR-8K   |     Online     |       CN       | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data |  30000   |   UniASR-8K   |    Offline     |       CN       | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data |  30000   |   UniASR-8K   |     Online     |       CN       | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
| Alibaba Speech Data |  30000   |   UniASR-8K   |    Offline     |       CN       | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
| AISHELL-1 |  178  |   Paraformer   | Offline |       CN       | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) |
| AISHELL-2 | 1000  |   Paraformer   |   Offline   |       CN       | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) |
| AISHELL-1 |  178  | ParaformerBert |   Offline   |       CN       | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000  | ParaformerBert |   Offline   |       CN       | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
| AISHELL-1 |  178  |   Conformer   |    Offline     |       CN       | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000  |   Conformer   |    Offline     |       CN       | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
### Speech Recognition Models
#### Paraformer Models
|                                                                     Model Name                                                                     | Language |          Training Data           | Vocab Size | Parameter | Offline/Online | Notes                                                                                                                           |
|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
|        [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)        | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Which could deal with arbitrary length input wav                                                                                |
| [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8404    |   220M    |    Offline     | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
|              [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary)              | CN & EN  | Alibaba Speech Data (50000hours) |    8358    |    68M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
|          [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary)           | CN & EN  | Alibaba Speech Data (50000hours) |    8404    |    68M    |     Online     | Which could deal with streaming input                                                                                           |
|       [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary)       |    CN    |  Alibaba Speech Data (200hours)  |    544     |   5.2M    |    Offline     | Lightweight Paraformer model which supports Mandarin command words recognition                                                  |
|                   [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary)                   |    CN    |        AISHELL (178hours)        |    4234    |    43M    |    Offline     |                                                                                                                                 |
|       [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary)       |    CN    |        AISHELL (178hours)        |    4234    |    43M    |    Offline     |                                                                                                                                 |
|        [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary)         |    CN    |      AISHELL-2 (1000hours)       |    5212    |    64M    |    Offline     |                                                                                                                                 |
|    [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary)     |    CN    |      AISHELL-2 (1000hours)       |    5212    |    64M    |    Offline     |                                                                                                                                 |
#### UniASR Models
|                                                               Model Name                                                               | Language |          Training Data           | Vocab Size | Parameter | Offline/Online | Notes                                                                                                                           |
|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
|       [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary)        | CN & EN  | Alibaba Speech Data (60000hours) |    8358    |   100M    |     Online     | UniASR streaming offline unifying models                                                                                                    |
| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN  | Alibaba Speech Data (60000hours) |    8358    |   220M    |    Offline     | UniASR streaming offline unifying models                                                                                                    |
|           [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary)           | Burmese  |  Alibaba Speech Data (? hours)   |    696     |    95M    |     Online     | UniASR streaming offline unifying models                                                                                                    |
|           [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary)           |  Hebrew  |  Alibaba Speech Data (? hours)   |    1085    |    95M    |     Online     | UniASR streaming offline unifying models                                                                                                    |
|       [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary)                  |   Urdu   |  Alibaba Speech Data (? hours)   |    877     |    95M    |     Online     | UniASR streaming offline unifying models                                                                                                    |
#### Conformer Models
#### Paraformer Models
|                                                       Model Name                                                       | Language |     Training Data     | Vocab Size | Parameter | Offline/Online | Notes                                                                                                                           |
|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary)   |   CN     |  AISHELL (178hours)   |    4234    |    44M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary)   |   CN     | AISHELL-2 (1000hours) |    5212    |    44M    |    Offline     | Duration of input wav <= 20s                                                                                                    |
#### RNN-T Models
### Voice Activity Detection Models
|                                           Model Name                                           |        Training Data         | Parameters | Sampling Rate | Notes |
|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) |    0.4M    |     16000     |       |
|   [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary)        | Alibaba Speech Data (5000hours) |    0.4M    |     8000      |       |
### Punctuation Restoration Models
|                                                         Model Name                                                         |        Training Data         | Parameters | Vocab Size| Offline/Online | Notes |
|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
|      [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)      | Alibaba Text Data |    70M     |    272727     |    Offline     |   offline punctuation model    |
| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary)      | Alibaba Text Data |    70M     |    272727     |     Online     |  online punctuation model     |
### Language Models
|                                                       Model Name                                                       |        Training Data         | Parameters | Vocab Size | Notes |
|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary)      | Alibaba Speech Data (?hours) |    57M     |    8404    |       |
### Speaker Verification Models
|                                                  Model Name                                                   |   Training Data   | Parameters | Vocab Size | Notes |
|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (?hours)  |   17.5M    |    3465    |       |
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (?hours) |    61M     |    6135    |       |
### Speaker Diarization Models
|                                                    Model Name                                                    |    Training Data    | Parameters | Notes |
|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (?hours) |   40.5M    |       |
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary)    |  CallHome (?hours)  |     12M     |       |
egs/aishell/transformer/utils/compute_wer.py
@@ -45,8 +45,8 @@
           if out_item['wrong'] > 0:
               rst['wrong_sentences'] += 1
           cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
           cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
           cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
           cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
           cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -74,7 +74,7 @@
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
@@ -38,7 +38,7 @@
    # computer CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
@@ -74,7 +74,7 @@
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py
@@ -38,7 +38,7 @@
    # computer CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
@@ -63,8 +63,8 @@
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
@@ -34,7 +34,7 @@
    # computer CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh
@@ -63,8 +63,8 @@
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py
@@ -34,7 +34,7 @@
    # computer CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
@@ -23,8 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
    inference_pipline(audio_in=audio_in)
def modelscope_infer(params):
    # prepare for multi-GPU decoding
@@ -75,7 +74,7 @@
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
@@ -2,52 +2,103 @@
import os
import shutil
from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx):
    """Run one decoding job over its own shard of the wav list.

    Args:
        model_dir: Model location handed to the ModelScope pipeline.
        output_dir: Root result directory; this job writes to ``output.<idx>``
            beneath it.
        split_dir: Directory holding the per-job ``wav.<idx>.scp`` files.
        njob: Number of jobs that share one device slot.
        idx: 1-based job index (anything ``int()`` accepts).

    Side effects:
        Overwrites ``CUDA_VISIBLE_DEVICES`` in this process's environment so
        that only the device selected for this job remains listed.
    """
    job_output_dir = os.path.join(output_dir, "output.{}".format(idx))

    # Jobs are numbered from 1; each run of `njob` consecutive jobs maps to
    # the same device slot.
    device_slot = (int(idx) - 1) // njob
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # No restriction set: use the slot number itself as the device id.
        selected_device = str(device_slot)
    else:
        # Pick the slot-th entry of the already restricted device list.
        selected_device = str(visible_devices.split(",")[device_slot])
    os.environ['CUDA_VISIBLE_DEVICES'] = selected_device

    recognizer = pipeline(
        task=Tasks.auto_speech_recognition,
        model=model_dir,
        output_dir=job_output_dir,
        batch_size=1,
    )
    job_scp = os.path.join(split_dir, "wav.{}.scp".format(idx))
    recognizer(audio_in=job_scp)
def modelscope_infer_after_finetune(params):
    # prepare for decoding
    # prepare for multi-GPU decoding
    model_dir = params["model_dir"]
    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
    for file_name in params["required_files"]:
        if file_name == "configuration.json":
            with open(os.path.join(pretrained_model_path, file_name)) as f:
                config_dict = json.load(f)
                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
            with open(os.path.join(model_dir, "configuration.json"), "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(os.path.join(pretrained_model_path, file_name),
                        os.path.join(params["output_dir"], file_name))
    decoding_path = os.path.join(params["output_dir"], "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)
                        os.path.join(model_dir, file_name))
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)
    nj = ngpu * njob
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
        num_lines = len(lines)
        num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end
    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=params["output_dir"],
        output_dir=decoding_path,
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
    p = Pool(nj)
    for i in range(nj):
        p.apply_async(modelscope_infer_after_finetune_core,
                      args=(model_dir, output_dir, split_dir, njob, str(i + 1)))
    p.close()
    p.join()
    # computer CER if GT text is set
    # combine decoding results
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    files = ["text", "token", "score"]
    for file in files:
        with open(os.path.join(best_recog_path, file), "w") as f:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
                with open(job_file) as f_job:
                    lines = f_job.readlines()
                f.writelines(lines)
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
        text_proc_file = os.path.join(best_recog_path, "token")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
if __name__ == '__main__':
    params = {}
    params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline"
    params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
    params["output_dir"] = "./checkpoint"
    params["model_dir"] = "./checkpoint"
    params["output_dir"] = "./results"
    params["data_dir"] = "./data/test"
    params["decoding_model_name"] = "20epoch.pb"
    params["ngpu"] = 1
    params["njob"] = 1
    modelscope_infer_after_finetune(params)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
@@ -75,7 +75,7 @@
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
@@ -39,7 +39,7 @@
    # computer CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
File was renamed from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
@@ -1,3 +1,9 @@
"""
Author: Speech Lab, Alibaba Group, China
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
copy from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py copy to egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
File was copied from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
@@ -1,3 +1,9 @@
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -8,17 +14,18 @@
    num_workers=0,
    task=Tasks.speaker_diarization,
    diar_model_config="sond.yaml",
    model='damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch',
    sv_model="damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch",
    model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
    sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
    sv_model_revision="master",
)
# 以 audio_list 作为输入,其中第一个音频为待检测语音,后面的音频为不同说话人的声纹注册语音
audio_list = [
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
]
results = inference_diar_pipline(audio_in=audio_list)
funasr/bin/asr_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/asr_inference_paraformer.py
@@ -797,7 +797,7 @@
                        finish_count += 1
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = text_postprocessed
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
funasr/bin/asr_inference_paraformer_streaming.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -607,17 +608,21 @@
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
        is_final = False
        if param_dict is not None and "cache" in param_dict:
            cache = param_dict["cache"]
        if param_dict is not None and "is_final" in param_dict:
            is_final = param_dict["is_final"]
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            is_final = True
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
funasr/bin/asr_inference_paraformer_vad.py
@@ -338,7 +338,7 @@
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text"][key] = " ".join(word_lists)
                    ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                    if time_stamp_postprocessed is not None:
                        ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -670,7 +670,7 @@
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text"][key] = " ".join(word_lists)
                    ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                    if time_stamp_postprocessed is not None:
                        ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
funasr/bin/asr_inference_rnnt.py
@@ -738,13 +738,13 @@
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                        text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                        item = {'key': key, 'value': text_postprocessed}
                        asr_result_list.append(item)
                        finish_count += 1
                        # asr_utils.print_progress(finish_count / file_count)
                        if writer is not None:
                            ibest_writer["text"][key] = text_postprocessed
                            ibest_writer["text"][key] = " ".join(word_lists)
                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
funasr/bin/asr_inference_uniasr.py
@@ -37,9 +37,6 @@
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
    """Speech2Text class
@@ -507,13 +504,13 @@
                    ibest_writer["score"][key] = str(hyp.score)
    
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text_postprocessed
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    
    return _forward
funasr/bin/asr_inference_uniasr_vad.py
@@ -507,13 +507,13 @@
                    ibest_writer["score"][key] = str(hyp.score)
    
                if text is not None:
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text_postprocessed
                        ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list
    
    return _forward
funasr/bin/diar_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/lm_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/punc_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/punc_train_vadrealtime.py
File was deleted
funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -90,7 +90,7 @@
            data = {
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
            }
            data = to_device(data, self.device)
            y, _ = self.wrapped_model(**data)
funasr/bin/sv_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/tp_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/bin/vad_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
funasr/datasets/dataset.py
@@ -115,7 +115,7 @@
    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to [-1,1] range.
    loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
    loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate = dest_sample_rate)
    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but ndarray is desired, so Adapter class is inserted here
funasr/datasets/large_datasets/utils/tokenize.py
@@ -47,8 +47,8 @@
    length = len(text)
    for i in range(length):
        x = text[i]
        if i == length-1 and "punc" in data and text[i].startswith("vad:"):
            vad = x[-1][4:]
        if i == length-1 and "punc" in data and x.startswith("vad:"):
            vad = x[4:]
            if len(vad) == 0:
                vad = -1
            else:
funasr/datasets/preprocessor.py
@@ -786,6 +786,7 @@
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            #import pdb; pdb.set_trace()
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
@@ -800,3 +801,17 @@
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Cut ``words`` into consecutive chunks of at most ``word_limit`` items.

    Args:
        words: Sequence to split; only ``len()`` and slicing are used.
        word_limit: Maximum chunk size; must be greater than 1.

    Returns:
        A list of chunks in the original order. When ``words`` already fits
        within ``word_limit`` (including the empty case) the result is
        ``[words]``, holding the very same object rather than a copy.
    """
    assert word_limit > 1
    total = len(words)
    if total <= word_limit:
        return [words]
    # Stepping by word_limit yields every full chunk followed by the shorter
    # remainder chunk, if there is one.
    return [words[begin:begin + word_limit] for begin in range(0, total, word_limit)]
funasr/export/README.md
@@ -7,7 +7,7 @@
## Install modelscope and funasr
The installation is the same as [funasr](../../README.md)
The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
## Export model
   `Tips`: torch>=1.11.0
funasr/export/export_model.py
@@ -167,31 +167,57 @@
    
    def export(self,
               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
               mode: str = 'paraformer',
               mode: str = None,
               ):
        
        model_dir = tag_name
        if model_dir.startswith('damo/'):
        if model_dir.startswith('damo'):
            from modelscope.hub.snapshot_download import snapshot_download
            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
        asr_train_config = os.path.join(model_dir, 'config.yaml')
        asr_model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        json_file = os.path.join(model_dir, 'configuration.json')
        if mode is None:
            import json
            json_file = os.path.join(model_dir, 'configuration.json')
            with open(json_file, 'r') as f:
                config_data = json.load(f)
                mode = config_data['model']['model_config']['mode']
                if config_data['task'] == "punctuation":
                    mode = config_data['model']['punc_model_config']['mode']
                else:
                    mode = config_data['model']['model_config']['mode']
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        elif mode.startswith('uniasr'):
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
            config = os.path.join(model_dir, 'config.yaml')
            model_file = os.path.join(model_dir, 'model.pb')
            cmvn_file = os.path.join(model_dir, 'am.mvn')
            model, asr_train_args = ASRTask.build_model_from_file(
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')
            model_file = os.path.join(model_dir, 'vad.pb')
            cmvn_file = os.path.join(model_dir, 'vad.mvn')
            
        model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, 'cpu'
        )
        self.frontend = model.frontend
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, cmvn_file=cmvn_file, device='cpu'
            )
            self.export_config["feats_dim"] = 400
            self.frontend = model.frontend
        elif mode.startswith('punc'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        elif mode.startswith('punc_VadRealtime'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        self._export(model, tag_name)
            
funasr/export/models/CT_Transformer.py
New file
@@ -0,0 +1,162 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class CT_Transformer(nn.Module):
    """Export wrapper around a CT-Transformer punctuation model.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf

    The wrapper reuses the embedding and decoder of the source model and swaps
    its SANM encoder for the export implementation.
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            model_name='punc_model',
            **kwargs,
    ):
        """Wrap ``model`` for export.

        Args:
            model: Source punctuation model exposing ``embed``, ``encoder``
                and ``decoder`` attributes.
            max_seq_len: Accepted for interface compatibility; not used by
                this constructor.
            model_name: Name stored on the wrapper as ``model_name``.
            **kwargs: An optional ``onnx`` flag is forwarded to the export
                encoder; other keys are ignored.

        Raises:
            AssertionError: If ``model.encoder`` is not a ``SANMEncoder``.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not isinstance(model.encoder, SANMEncoder):
            raise AssertionError("Only support samn encode.")

        self.embed = model.embed
        self.decoder = model.decoder
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name
        self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)

    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor:
        """Embed the token ids, encode them and return the decoder output.

        Args:
            inputs: Token ids; ``(batch, length)`` in the dummy inputs.
            text_lengths: Per-sample lengths passed to the encoder; ``(batch,)``
                in the dummy inputs.

        Returns:
            The decoder output, exported under the name ``logits``.
        """
        x = self.embed(inputs)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y

    def get_dummy_inputs(self):
        """Return a ``(text_indexes, text_lengths)`` pair for tracing/export."""
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        # Two samples of different lengths so the padding path is exercised.
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        """Return the input names, in ``forward`` argument order."""
        return ['inputs', 'text_lengths']

    def get_output_names(self):
        """Return the output names."""
        return ['logits']

    def get_dynamic_axes(self):
        """Return the axes that are left dynamic, keyed by tensor name."""
        return {
            'inputs': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
class CT_Transformer_VadRealtime(nn.Module):
    """Export wrapper around the VAD-realtime CT-Transformer punctuation model.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf

    The wrapper reuses the embedding and decoder of the source model and swaps
    its SANM VAD encoder for the export implementation.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name='punc_model',
        **kwargs,
    ):
        """Wrap ``model`` for export.

        Args:
            model: Source punctuation model exposing ``embed``, ``encoder``
                and ``decoder`` attributes.
            max_seq_len: Accepted for interface compatibility; not used by
                this constructor.
            model_name: Name stored on the wrapper as ``model_name``.
            **kwargs: An optional ``onnx`` flag is forwarded to the export
                encoder; other keys are ignored.

        Raises:
            AssertionError: If ``model.encoder`` is not a ``SANMVadEncoder``.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        # Raised explicitly (instead of ``assert False``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not isinstance(model.encoder, SANMVadEncoder):
            raise AssertionError("Only support samn encode.")

        self.embed = model.embed
        self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        self.decoder = model.decoder
        self.model_name = model_name

    def forward(self, inputs: torch.Tensor,
                text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor,
                sub_masks: torch.Tensor,
                ) -> torch.Tensor:
        """Embed the token ids, encode them with both masks and decode.

        Args:
            inputs: Token ids; ``(1, length)`` in the dummy inputs.
            text_lengths: Per-sample lengths passed to the encoder.
            vad_indexes: Mask passed to the encoder as its third argument;
                exported under the input name ``vad_masks``.
            sub_masks: Mask passed to the encoder as its fourth argument; the
                dummy input is lower-triangular.

        Returns:
            The decoder output, exported under the name ``logits``.
        """
        x = self.embed(inputs)
        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
        y = self.decoder(h)
        return y

    def with_vad(self):
        """Report that this wrapper takes the extra VAD mask inputs."""
        return True

    def get_dummy_inputs(self):
        """Return ``(text_indexes, text_lengths, vad_mask, sub_masks)`` for export."""
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
        text_lengths = torch.tensor([length], dtype=torch.int32)
        # Both masks are (1, 1, length, length); the VAD mask is all ones and
        # the sub mask is lower-triangular.
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)
        sub_masks = torch.tril(sub_masks).type(torch.float32)
        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])

    def get_input_names(self):
        """Return the input names, in ``forward`` argument order."""
        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']

    def get_output_names(self):
        """Return the output names."""
        return ['logits']

    def get_dynamic_axes(self):
        """Return the axes that are left dynamic, keyed by tensor name."""
        return {
            'inputs': {
                1: 'feats_length'
            },
            'vad_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'sub_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'logits': {
                1: 'logits_length'
            },
        }
funasr/export/models/__init__.py
@@ -1,13 +1,25 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
    """Return the export wrapper that matches the type of ``model``.

    Args:
        model: A trained FunASR model instance.
        export_config: Optional mapping of keyword arguments forwarded to the
            export wrapper's constructor. ``None`` is treated as empty.

    Returns:
        The export wrapper built around ``model`` (or around
        ``model.punc_model`` for punctuation models).

    Raises:
        TypeError: If no export wrapper exists for the given model type.
    """
    if export_config is None:
        export_config = {}

    # BiCifParaformer is tested before Paraformer; presumably it is a subclass
    # and would otherwise be caught by the Paraformer branch -- verify.
    if isinstance(model, BiCifParaformer):
        return BiCifParaformer_export(model, **export_config)
    if isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    if isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    if isinstance(model, PunctuationModel):
        punc_model = model.punc_model
        if isinstance(punc_model, TargetDelayTransformer):
            return CT_Transformer_export(punc_model, **export_config)
        if isinstance(punc_model, VadRealtimeTransformer):
            return CT_Transformer_VadRealtime_export(punc_model, **export_config)
        # An unsupported punctuation model falls through to the error below
        # rather than silently returning None.

    # Raising a bare string is itself a TypeError in Python 3, so TypeError
    # keeps the exception type while delivering the intended message.
    raise TypeError("Funasr does not support the given model type currently.")
funasr/export/models/e2e_asr_paraformer.py
@@ -19,7 +19,7 @@
class Paraformer(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
@@ -112,7 +112,7 @@
class BiCifParaformer(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
funasr/export/models/e2e_vad.py
New file
@@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
class E2EVadModel(nn.Module):
    """Export wrapper around an FSMN-based VAD model."""

    def __init__(self, model,
                max_seq_len=512,
                feats_dim=400,
                model_name='model',
                **kwargs,):
        """Wrap ``model`` for export.

        Args:
            model: Source VAD model; its ``encoder`` must be an ``FSMN``.
            max_seq_len: Stored on the wrapper as ``max_seq_len``.
            feats_dim: Feature dimension used for the dummy ``speech`` input.
            model_name: Name stored on the wrapper as ``model_name``.
            **kwargs: Accepted for interface compatibility; ignored.

        Raises:
            TypeError: If ``model.encoder`` is not an ``FSMN``.
        """
        super().__init__()
        self.feats_dim = feats_dim
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        if isinstance(model.encoder, FSMN):
            self.encoder = FSMN_export(model.encoder)
        else:
            # Raising a bare string is itself a TypeError in Python 3, so
            # TypeError keeps the type while delivering the intended message.
            raise TypeError("unsupported encoder")

    def forward(self, feats: torch.Tensor, *args, ):
        """Run the export encoder.

        Args:
            feats: Input features; ``(1, frame, feats_dim)`` in the dummy inputs.
            *args: Cache tensors forwarded unchanged to the encoder.

        Returns:
            ``(scores, out_caches)`` exactly as produced by the encoder.
        """
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches

    def get_dummy_inputs(self, frame=30):
        """Return ``(speech, in_cache0, ..., in_cache3)`` for tracing/export.

        Args:
            frame: Number of frames in the dummy ``speech`` tensor.
        """
        speech = torch.randn(1, frame, self.feats_dim)
        # NOTE(review): four caches of shape (1, 128, 19, 1) are hard-coded;
        # presumably they match the exported FSMN configuration -- confirm.
        in_caches = tuple(torch.randn(1, 128, 19, 1) for _ in range(4))
        return (speech, *in_caches)

    def get_input_names(self):
        """Return the input names, in ``get_dummy_inputs`` order."""
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']

    def get_output_names(self):
        """Return the output names."""
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']

    def get_dynamic_axes(self):
        """Return the axes that are left dynamic; only the speech frame axis."""
        return {
            'speech': {
                1: 'feats_length'
            },
        }
funasr/export/models/encoder/fsmn_encoder.py
New file
@@ -0,0 +1,296 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.encoder.fsmn_encoder import BasicBlock
class LinearTransform(nn.Module):
    """Bias-free linear projection from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)

    def forward(self, input):
        """Apply the projection to ``input``."""
        return self.linear(input)
class AffineTransform(nn.Module):
    """Linear projection with bias from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, input):
        """Apply the affine map to ``input``."""
        return self.linear(input)
class RectifiedLinear(nn.Module):
    """ReLU activation layer.

    ``output_dim`` is accepted but not used; a dropout module is created but
    ``forward`` does not apply it.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input):
        """Return ``relu(input)``."""
        return self.relu(input)
class FSMNBlock(nn.Module):
    """FSMN memory block: depthwise convolutions over left and right context.

    The left context comes from an explicit cache tensor that is concatenated
    in front of the current input along the time axis.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            lorder=None,
            rorder=None,
            lstride=1,
            rstride=1,
    ):
        super().__init__()
        self.dim = input_dim

        # Without a left order nothing else is configured on the block.
        if lorder is None:
            return

        self.lorder, self.rorder = lorder, rorder
        self.lstride, self.rstride = lstride, rstride

        channels = self.dim
        self.conv_left = nn.Conv2d(
            channels, channels, [lorder, 1], dilation=[lstride, 1], groups=channels, bias=False)
        # The right-context convolution only exists for a positive right order.
        self.conv_right = (
            nn.Conv2d(channels, channels, [rorder, 1], dilation=[rstride, 1], groups=channels, bias=False)
            if self.rorder > 0 else None
        )

    def forward(self, input: torch.Tensor, cache: torch.Tensor):
        """Apply the memory block and return ``(output, updated_cache)``.

        Args:
            input: Tensor that is viewed as (B, T, D) before being permuted.
            cache: Left-context frames, concatenated in front of the input on
                the time axis of the permuted (B, D, T, 1) layout.
        """
        feats = input.unsqueeze(1).permute(0, 3, 2, 1)  # B D T C
        history = torch.cat((cache.to(feats.device), feats), dim=2)

        # The trailing frames of the concatenation become the next cache.
        keep = (self.lorder - 1) * self.lstride
        next_cache = history[:, :, -keep:, :]

        out = feats + self.conv_left(history)
        if self.conv_right is not None:
            # Zero-pad the end of the time axis, then skip the first
            # ``rstride`` frames so the convolution looks ahead.
            lookahead = F.pad(feats, [0, 0, 0, self.rorder * self.rstride])
            out += self.conv_right(lookahead[:, :, self.rstride:, :])

        return out.permute(0, 3, 2, 1).squeeze(1), next_cache
class BasicBlock_export(nn.Module):
    """Export form of a basic FSMN block with an explicit cache in and out.

    The sub-modules (``linear``, ``fsmn_block``, ``affine``, ``relu``) are taken
    from the wrapped block without copying.
    """

    def __init__(self,
                 model,
                 ):
        super().__init__()
        for name in ("linear", "fsmn_block", "affine", "relu"):
            setattr(self, name, getattr(model, name))

    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
        """Return ``(output, out_cache)`` for one block.

        The input goes through ``linear``, then the memory block (which also
        consumes ``in_cache`` and yields ``out_cache``), then ``affine`` and
        ``relu``.
        """
        projected = self.linear(input)  # B T D
        memory, out_cache = self.fsmn_block(projected, in_cache)
        return self.relu(self.affine(memory)), out_cache
# class FsmnStack(nn.Sequential):
#     def __init__(self, *args):
#         super(FsmnStack, self).__init__(*args)
#
#     def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
#         x = input
#         for module in self._modules.values():
#             x = module(x, in_cache)
#         return x
'''
FSMN net for keyword spotting
input_dim:              input dimension
linear_dim:             fsmn input dimension
proj_dim:               fsmn projection dimension
lorder:                 fsmn left order
rorder:                 fsmn right order
num_syn:                output dimension
fsmn_layers:            no. of sequential fsmn layers
'''
class FSMN(nn.Module):
    """Export form of the FSMN encoder with explicit per-layer caches.

    All sub-modules are taken from the wrapped encoder; every ``BasicBlock``
    in its ``fsmn`` stack is replaced by a ``BasicBlock_export``.
    """

    def __init__(
            self, model,
    ):
        super().__init__()
        # Attributes are re-registered in a fixed order, straight from the
        # source encoder (no copies are made).
        for name in ("in_linear1", "in_linear2", "relu",
                     "out_linear1", "out_linear2", "softmax", "fsmn"):
            setattr(self, name, getattr(model, name))

        # ``self.fsmn`` is the same container as ``model.fsmn``, so the
        # replacement below also changes the source encoder.
        for idx, layer in enumerate(model.fsmn):
            if isinstance(layer, BasicBlock):
                self.fsmn[idx] = BasicBlock_export(layer)

    def fuse_modules(self):
        """No-op placeholder."""
        pass

    def forward(
            self,
            input: torch.Tensor,
            *args,
    ):
        """Run the encoder with one explicit cache tensor per layer.

        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            *args: Cache tensors, one per layer of ``self.fsmn``, given in
                layer order.

        Returns:
            ``(scores, out_caches)`` where ``out_caches`` is a list holding the
            updated cache of every layer, in layer order.
        """
        hidden = self.relu(self.in_linear2(self.in_linear1(input)))

        out_caches = []
        for idx, layer in enumerate(self.fsmn):
            hidden, layer_cache = layer(hidden, args[idx])
            out_caches.append(layer_cache)

        scores = self.softmax(self.out_linear2(self.out_linear1(hidden)))
        return scores, out_caches
'''
one deep fsmn layer
dimproj:                projection dimension, input and output dimension of memory blocks
dimlinear:              dimension of mapping layer
lorder:                 left order
rorder:                 right order
lstride:                left stride
rstride:                right stride
'''
class DFSMN(nn.Module):
    """One deep-FSMN layer: expand, shrink, memory convolutions, skip connection."""

    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
        super().__init__()
        self.lorder, self.rorder = lorder, rorder
        self.lstride, self.rstride = lstride, rstride

        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)

        self.conv_left = nn.Conv2d(
            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
        # The right-context convolution only exists for a positive right order.
        self.conv_right = (
            nn.Conv2d(dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
            if rorder > 0 else None
        )

    def forward(self, input):
        """Return ``input`` plus the output of the memory block."""
        projected = self.shrink(F.relu(self.expand(input)))
        feats = projected.unsqueeze(1).permute(0, 3, 2, 1)

        # Zero-pad the start of the time axis so the left convolution keeps
        # the number of frames.
        left_ctx = F.pad(feats, [0, 0, (self.lorder - 1) * self.lstride, 0])
        memory = feats + self.conv_left(left_ctx)
        if self.conv_right is not None:
            right_ctx = F.pad(feats, [0, 0, 0, (self.rorder) * self.rstride])
            memory = memory + self.conv_right(right_ctx[:, :, self.rstride:, :])

        return input + memory.permute(0, 3, 2, 1).squeeze(1)
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
    """Stack ``fsmn_layers`` DFSMN layers (both strides 1) into one ``nn.Sequential``.

    Each layer is itself wrapped in a single-element ``nn.Sequential``.
    """
    layers = []
    for _ in range(fsmn_layers):
        layers.append(nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)))
    return nn.Sequential(*layers)
if __name__ == '__main__':
    # Smoke test for the export encoder. The export ``FSMN`` wraps an existing
    # encoder, so the source encoder is built first and then wrapped.
    from funasr.models.encoder.fsmn_encoder import FSMN as FSMN_source

    # NOTE(review): the positional order below is taken from the parameter
    # names in the commented-out reference constructor (input_dim,
    # input_affine_dim, fsmn_layers, linear_dim, proj_dim, lorder, rorder,
    # lstride, rstride, output_affine_dim, output_dim) -- confirm against the
    # source encoder's signature.
    fsmn_layers, proj_dim, lorder, lstride = 4, 128, 10, 1
    source = FSMN_source(400, 140, fsmn_layers, 250, proj_dim, lorder, 2, lstride, 1, 140, 2599)
    fsmn = FSMN(source)
    print(fsmn)
    num_params = sum(p.numel() for p in fsmn.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    # One zero cache per layer, shaped (B, proj_dim, (lorder - 1) * lstride, 1).
    caches = [
        torch.zeros(x.shape[0], proj_dim, (lorder - 1) * lstride, 1)
        for _ in range(fsmn_layers)
    ]
    y, _ = fsmn(x, *caches)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))
funasr/export/models/encoder/sanm_encoder.py
@@ -9,6 +9,7 @@
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
class SANMEncoder(nn.Module):
    def __init__(
        self,
@@ -107,3 +108,106 @@
            }
        }
class SANMVadEncoder(nn.Module):
    """Export form of the SANM VAD encoder.

    The wrapped encoder's layers are converted to their export counterparts in
    place, so the source model is modified by construction.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        # Pick the padding-mask builder that suits the export target.
        mask_builder = MakePadMask if onnx else sequence_mask
        self.make_pad_mask = mask_builder(max_seq_len, flip=False)

        if hasattr(model, 'encoders0'):
            self._to_export_layers(self.model.encoders0)
        self._to_export_layers(self.model.encoders)

        self.model_name = model_name
        first_layer = model.encoders[0]
        self.num_heads = first_layer.self_attn.h
        self.hidden_size = first_layer.self_attn.linear_out.out_features

    @staticmethod
    def _to_export_layers(layers):
        """Replace every layer in ``layers`` by its export wrapper, in place."""
        for idx, layer in enumerate(layers):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANM):
                layer.self_attn = MultiHeadedAttentionSANM_export(layer.self_attn)
            if isinstance(layer.feed_forward, PositionwiseFeedForward):
                layer.feed_forward = PositionwiseFeedForward_export(layer.feed_forward)
            layers[idx] = EncoderLayerSANM_export(layer)

    def prepare_mask(self, mask, sub_masks):
        """Return ``(mask with a trailing axis, additive mask from sub_masks)``.

        Entries of ``sub_masks`` equal to 1 map to 0 and entries equal to 0 map
        to -10000.0.
        """
        return mask[:, :, None], (1 - sub_masks) * -10000.0

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                vad_masks: torch.Tensor,
                sub_masks: torch.Tensor,
                ):
        """Encode ``speech`` and return ``(encoded, speech_lengths)``.

        All layers receive the mask built from ``sub_masks`` except the last
        layer of ``encoders``, which receives the one built from ``vad_masks``.
        """
        speech = speech * self._output_size ** 0.5
        pad_mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(pad_mask, vad_masks)
        mask = self.prepare_mask(pad_mask, sub_masks)

        xs_pad = speech if self.embed is None else self.embed(speech)

        outs = self.model.encoders0(xs_pad, mask)
        xs_pad, _ = outs[0], outs[1]

        last_idx = len(self.model.encoders) - 1
        for layer_idx, encoder_layer in enumerate(self.model.encoders):
            if layer_idx == last_idx:
                mask = vad_masks
            outs = encoder_layer(xs_pad, mask)
            xs_pad, _ = outs[0], outs[1]

        return self.model.after_norm(xs_pad), speech_lengths

    def get_output_size(self):
        """Return the ``size`` attribute of the first encoder layer."""
        return self.model.encoders[0].size
    # def get_dummy_inputs(self):
    #     feats = torch.randn(1, 100, self.feats_dim)
    #     return (feats)
    #
    # def get_input_names(self):
    #     return ['feats']
    #
    # def get_output_names(self):
    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
    #
    # def get_dynamic_axes(self):
    #     return {
    #         'feats': {
    #             1: 'feats_length'
    #         },
    #         'encoder_out': {
    #             1: 'enc_out_length'
    #         },
    #         'predictor_weight': {
    #             1: 'pre_out_length'
    #         }
    #
    #     }
funasr/export/test/__init__.py
funasr/export/test/test_onnx.py
funasr/export/test/test_onnx_punc.py
New file
@@ -0,0 +1,18 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
    # Load the exported punctuation model and run it once on dummy input.
    onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [node.name for node in sess.get_inputs()]
    output_name = [node.name for node in sess.get_outputs()]

    def _get_feed_dict(text_length):
        """Build a one-sample feed dict of all-one token ids."""
        token_ids = np.ones((1, text_length), dtype=np.int64)
        lengths = np.array([text_length,], dtype=np.int32)
        return {'inputs': token_ids, 'text_lengths': lengths}

    def _run(feed_dict):
        """Run the session and print every output by name."""
        results = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, results):
            print('{}: {}'.format(name, value))

    _run(_get_feed_dict(10))
funasr/export/test/test_onnx_punc_vadrealtime.py
New file
@@ -0,0 +1,22 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
    # Load the exported VAD-realtime punctuation model and run it once.
    onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [node.name for node in sess.get_inputs()]
    output_name = [node.name for node in sess.get_outputs()]

    def _get_feed_dict(text_length):
        """Build a one-sample feed dict: all-one ids and VAD mask, lower-triangular sub mask."""
        token_ids = np.ones((1, text_length), dtype=np.int64)
        lengths = np.array([text_length,], dtype=np.int32)
        vad_masks = np.ones((1, 1, text_length, text_length), dtype=np.float32)
        lower = np.tril(np.ones((text_length, text_length), dtype=np.float32))
        sub_masks = lower[None, None, :, :].astype(np.float32)
        return {'inputs': token_ids,
                'text_lengths': lengths,
                'vad_masks': vad_masks,
                'sub_masks': sub_masks,
                }

    def _run(feed_dict):
        """Run the session and print every output by name."""
        results = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, results):
            print('{}: {}'.format(name, value))

    _run(_get_feed_dict(10))
funasr/export/test/test_onnx_vad.py
New file
@@ -0,0 +1,26 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
    # Load the exported VAD model and run it on two different input lengths.
    onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [node.name for node in sess.get_inputs()]
    output_name = [node.name for node in sess.get_outputs()]

    def _get_feed_dict(feats_length):
        """Build a feed dict with random speech features and four random caches."""
        feed = {'speech': np.random.rand(1, feats_length, 400).astype(np.float32)}
        for cache_name in ('in_cache0', 'in_cache1', 'in_cache2', 'in_cache3'):
            feed[cache_name] = np.random.rand(1, 128, 19, 1).astype(np.float32)
        return feed

    def _run(feed_dict):
        """Run the session and print the shape of every output by name."""
        results = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, results):
            print('{}: {}'.format(name, value.shape))

    for feats_length in (100, 200):
        _run(_get_feed_dict(feats_length))
funasr/export/test/test_torchscripts.py
funasr/lm/abs_model.py
@@ -5,7 +5,17 @@
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
    """The abstract LM class
@@ -27,3 +37,122 @@
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
class LanguageModel(AbsESPnetModel):
    """Training wrapper that turns an ``AbsLM`` into a model with a loss."""

    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        """Store the language model and the special token ids.

        Args:
            lm: The language model that is called on the shifted input.
            vocab_size: Accepted for interface compatibility; not used by
                this constructor.
            ignore_id: Padding id used for the target sequence.
        """
        assert check_argument_types()
        super().__init__()
        self.lm = lm
        self.sos = 1
        self.eos = 2

        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id

    def nll(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        max_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.

        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            max_length: int, length to truncate ``text`` to; the longest
                entry of ``text_lengths`` is used when it is None.

        Returns:
            ``(nll, x_lengths)``: the per-token nll of shape (Batch, Length + 1)
            with padded positions set to 0, and ``text_lengths + 1``.
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]

        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # text: (Batch, Length) -> x, t: (Batch, Length + 1)
        x = F.pad(text, [1, 0], "constant", self.sos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        # Put <eos> right after the last real token of every sample.
        for i, length in enumerate(text_lengths):
            t[i, length] = self.eos
        x_lengths = text_lengths + 1

        # 2. Forward Language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.lm(x, None)

        # 3. Calc negative log likelihood
        # nll: (BxL,)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        # Zero out the padded positions; the shape stays (BxL,).
        if max_length is None:
            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, x_lengths

    def batchify_nll(
        self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll) from transformer language model

        To avoid OOM, this function separates the input into batches.
        Then call nll for each batch and combine and return results.

        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase

        Raises:
            ValueError: If the input has to be split and ``batch_size`` is not
                positive.
        """
        total_num = text.size(0)
        if total_num <= batch_size:
            nll, x_lengths = self.nll(text, text_lengths)
        else:
            # A non-positive batch size would never advance through the input.
            if batch_size <= 0:
                raise ValueError(
                    "batch_size must be positive, got {}".format(batch_size)
                )
            nlls = []
            x_lengths = []
            # All chunks share one max length so the results can be concatenated.
            max_length = text_lengths.max()

            for start_idx in range(0, total_num, batch_size):
                end_idx = min(start_idx + batch_size, total_num)
                batch_text = text[start_idx:end_idx, :]
                batch_text_lengths = text_lengths[start_idx:end_idx]
                batch_nll, batch_x_lengths = self.nll(
                    batch_text, batch_text_lengths, max_length=max_length
                )
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
            nll = torch.cat(nlls)
            x_lengths = torch.cat(x_lengths)
        assert nll.size(0) == total_num
        assert x_lengths.size(0) == total_num
        return nll, x_lengths

    def forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Return ``(loss, stats, weight)`` for one batch.

        The loss is the summed nll divided by the number of tokens (including
        the extra position per sample); that token count is also the weight.
        """
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Return an empty dict; no features are collected for this model."""
        return {}
funasr/lm/espnet_model.py
File was deleted
funasr/models/decoder/contextual_decoder.py
@@ -102,7 +102,7 @@
class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
funasr/models/decoder/sanm_decoder.py
@@ -104,7 +104,6 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -152,7 +151,7 @@
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
@@ -400,7 +399,7 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder(
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
@@ -410,13 +409,13 @@
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder(
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder(
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )
@@ -813,7 +812,7 @@
class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
@@ -1077,7 +1076,7 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder(
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)
@@ -1087,14 +1086,14 @@
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder(
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, None, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder(
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )
funasr/models/decoder/transformer_decoder.py
@@ -405,7 +405,7 @@
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
funasr/models/e2e_asr_mfcca.py
@@ -36,7 +36,11 @@
import random
import math
class MFCCA(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
    https://arxiv.org/abs/2210.05265
    """
    def __init__(
        self,
funasr/models/e2e_asr_paraformer.py
@@ -44,7 +44,7 @@
class Paraformer(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
@@ -612,7 +612,7 @@
class ParaformerBert(Paraformer):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
    """
funasr/models/e2e_diar_sond.py
@@ -36,8 +36,12 @@
class DiarSondModel(AbsESPnetModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """
    def __init__(
funasr/models/e2e_sv.py
@@ -1,3 +1,7 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
funasr/models/e2e_tp.py
@@ -32,7 +32,7 @@
class TimestampPredictor(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
    def __init__(
funasr/models/e2e_uni_asr.py
@@ -40,7 +40,7 @@
class UniASR(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
    def __init__(
funasr/models/e2e_vad.py
@@ -35,6 +35,11 @@
class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(
            self,
            sample_rate: int = 16000,
@@ -99,6 +104,11 @@
class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
@@ -117,6 +127,11 @@
class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
@@ -126,6 +141,11 @@
class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        self.window_size_ms = window_size_ms
@@ -192,7 +212,12 @@
class E2EVadModel(nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -229,6 +254,7 @@
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
        self.frontend = frontend
    def AllResetDetection(self):
        self.is_final = False
@@ -459,8 +485,8 @@
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                        i].contain_seg_end_point:
                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                        i].contain_seg_end_point):
                        continue
                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                    segment_batch.append(segment)
@@ -477,8 +503,9 @@
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, in_cache)
        self.ComputeDecibel()
        if not is_final:
            self.DetectCommonFrames()
        else:
funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -67,7 +67,7 @@
class ConvEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in OpenNMT framework
    """
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -117,7 +117,7 @@
class SelfAttentionEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Self attention encoder in OpenNMT framework
    """
funasr/models/encoder/resnet34_encoder.py
@@ -406,6 +406,12 @@
            tf2torch_tensor_name_prefix_torch="encoder",
            tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
        """
        Author: Speech Lab, Alibaba Group, China
        SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
        https://arxiv.org/abs/2211.10243
        """
        super(ResNet34Diar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
@@ -633,6 +639,12 @@
            tf2torch_tensor_name_prefix_torch="encoder",
            tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
        """
        Author: Speech Lab, Alibaba Group, China
        TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
        https://arxiv.org/abs/2303.05397
        """
        super(ResNet34SpL2RegDiar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
    def __init__(
@@ -117,7 +117,7 @@
class SANMEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
@@ -549,7 +549,7 @@
class SANMEncoderChunkOpt(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
@@ -958,3 +958,231 @@
                                                                                      var_dict_tf[name_tf].shape))
    
        return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group

    SAN-M encoder stack in which every layer except the last attends with a
    causal (no-future) mask, while the last layer attends with a mask built
    by ``vad_mask`` from per-utterance ``vad_indexes``.

    Only ``selfattention_layer_type="sanm"`` is supported; the layers are
    built from ``MultiHeadedAttentionSANMwithMask``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        """Build the input embedding, one input-size layer and ``num_blocks - 1`` further layers.

        Raises:
            ValueError: if ``input_layer`` is not one of the known options.
            NotImplementedError: if ``positionwise_layer_type`` or
                ``selfattention_layer_type`` is not supported.
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        # ---- input embedding -------------------------------------------
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.normalize_before = normalize_before

        # ---- position-wise feed-forward --------------------------------
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # ---- self-attention --------------------------------------------
        # Only the SAN-M attention is wired up. Previously any other value
        # (including "selfattn") left the attention class / first-layer args
        # undefined and crashed later with an opaque AttributeError/NameError,
        # so fail early with a clear message instead.
        if selfattention_layer_type != "sanm":
            raise NotImplementedError(
                "SANMVadEncoder supports only selfattention_layer_type='sanm', got: "
                + str(selfattention_layer_type)
            )
        self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
        # First layer maps input_size -> output_size ...
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        # ... the remaining layers keep output_size.
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Copy so the instance never aliases the shared default list.
        self.interctc_layer_idx = list(interctc_layer_idx)
        if len(self.interctc_layer_idx) > 0:
            assert 0 < min(self.interctc_layer_idx) and max(self.interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        """Return the encoder output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch, using a VAD-based mask in the last layer.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            vad_indexes: indexed per batch item and passed to ``vad_mask`` to
                build the last layer's attention mask (presumably the VAD
                split position of each utterance -- TODO confirm).
            prev_states: Not to be used now.
            ctc: unused in this method.
        Returns:
            Tuple ``(encoded tensor, output lengths, None)``.
        Raises:
            TooShortUttError: if a conv2d subsampling front-end is used and
                the input has too few frames.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks

        # Out-of-place scaling: the original `*=` modified the caller's tensor.
        xs_pad = xs_pad * self.output_size() ** 0.5

        subsampling_types = (
            Conv2dSubsampling,
            Conv2dSubsampling2,
            Conv2dSubsampling6,
            Conv2dSubsampling8,
        )
        if isinstance(self.embed, subsampling_types):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, [masks, no_future_masks])
        xs_pad = encoder_outs[0]

        # Never populated in this method; kept so the return shape matches
        # the sibling encoders' convention.
        intermediate_outs = []

        last_layer_idx = len(self.encoders) - 1
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx == last_layer_idx:
                # Last layer: replace the causal mask by a per-utterance VAD mask.
                coner_mask = torch.ones(
                    masks.size(0),
                    masks.size(-1),
                    masks.size(-1),
                    device=xs_pad.device,
                    dtype=torch.bool,
                )
                for batch_index in range(len(ilens)):
                    coner_mask[batch_index, :, :] = vad_mask(
                        masks.size(-1),
                        vad_indexes[batch_index],
                        device=xs_pad.device,
                    )
                layer_mask = masks & coner_mask
            else:
                layer_mask = no_future_masks
            encoder_outs = encoder_layer(xs_pad, [masks, layer_mask])
            xs_pad = encoder_outs[0]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
funasr/models/frontend/wav_frontend.py
@@ -38,7 +38,7 @@
    return cmvn
def apply_cmvn(inputs, cmvn_file):  # noqa
def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data
    """
@@ -47,7 +47,6 @@
    dtype = inputs.dtype
    frame, dim = inputs.shape
    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
@@ -111,6 +110,7 @@
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
    def output_size(self) -> int:
        return self.n_mels * self.lfr_m
@@ -140,8 +140,8 @@
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
@@ -194,8 +194,8 @@
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
funasr/models/predictor/cif.py
@@ -234,6 +234,7 @@
        last_fire_place = len_time - 1
        last_fire_remainds = 0.0
        pre_alphas_length = 0
        last_fire = False
 
        mask_chunk_peak_predictor = None
        if cache is not None:
@@ -251,10 +252,15 @@
            if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
                last_fire_place = len_time - 1 - i
                last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
                last_fire = True
                break
        last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
        cache["cif_hidden"] = hidden[:, last_fire_place:, :]
        cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
        if last_fire:
           last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
           cache["cif_hidden"] = hidden[:, last_fire_place:, :]
           cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
        else:
           cache["cif_hidden"] = hidden
           cache["cif_alphas"] = alphas
        token_num_int = token_num.floor().type(torch.int32).item()
        return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
funasr/models/target_delay_transformer.py
File was renamed from funasr/punctuation/target_delay_transformer.py
@@ -5,16 +5,19 @@
import torch
import torch.nn as nn
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
        self,
        vocab_size: int,
funasr/models/vad_realtime_transformer.py
File was renamed from funasr/punctuation/vad_realtime_transformer.py
@@ -6,12 +6,16 @@
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
        self,
        vocab_size: int,
funasr/modules/streaming_utils/chunk_utilis.py
@@ -11,7 +11,7 @@
class overlap_chunk():
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
funasr/punctuation/abs_model.py
File was deleted
funasr/punctuation/sanm_encoder.py
File was deleted
funasr/punctuation/text_preprocessor.py
File was deleted
funasr/runtime/grpc/CMakeLists.txt
@@ -74,7 +74,7 @@
    "${_target}.cc")
  target_link_libraries(${_target}
    rg_grpc_proto
    rapidasr
    funasr
    ${EXTRA_LIBS}
    ${_REFLECTION}
    ${_GRPC_GRPCPP}
funasr/runtime/grpc/Readme.md
@@ -53,6 +53,68 @@
python grpc_main_client_mic.py  --host $server_ip --port 10108
```
The `grpc_main_client_mic.py` follows the [original design](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data in chunks. If you want to send audio_data in one request, here is an example:
```
# go to ../python/grpc to find this package
import paraformer_pb2
class RecognizeStub:
    def __init__(self, channel):
        self.Recognize = channel.stream_stream(
                '/paraformer.ASR/Recognize',
                request_serializer=paraformer_pb2.Request.SerializeToString,
                response_deserializer=paraformer_pb2.Response.FromString,
                )
async def send(channel, data, speaking, isEnd):
    stub = RecognizeStub(channel)
    req = paraformer_pb2.Request()
    if data:
        req.audio_data = data
    req.user = 'zz'
    req.language = 'zh-CN'
    req.speaking = speaking
    req.isEnd = isEnd
    q = queue.SimpleQueue()
    q.put(req)
    return stub.Recognize(iter(q.get, None))
# send the audio data once
async def grpc_rec(data, grpc_uri):
    with grpc.insecure_channel(grpc_uri) as channel:
        b = time.time()
        response = await send(channel, data, False, False)
        resp = response.next()
        text = ''
        if 'decoding' == resp.action:
            resp = response.next()
            if 'finish' == resp.action:
                text = json.loads(resp.sentence)['text']
        response = await send(channel, None, False, True)
        return {
                'text': text,
                'time': time.time() - b,
                }
async def test():
    # fc = FunAsrGrpcClient('127.0.0.1', 9900)
    # t = await fc.rec(wav.tobytes())
    # print(t)
    wav, _ = sf.read('z-10s.wav', dtype='int16')
    uri = '127.0.0.1:9900'
    res = await grpc_rec(wav.tobytes(), uri)
    print(res)
if __name__ == '__main__':
    asyncio.run(test())
```
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
funasr/runtime/grpc/paraformer_server.cc
@@ -15,7 +15,6 @@
#include "paraformer.grpc.pb.h"
#include "paraformer_server.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
@@ -24,37 +23,14 @@
using grpc::ServerWriter;
using grpc::Status;
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
    AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
    AsrHanlde=FunASRInit(model_path, thread_num, quantize);
    std::cout << "ASRServicer init" << std::endl;
    init_flag = 0;
}
void ASRServicer::clear_states(const std::string& user) {
    clear_buffers(user);
    clear_transcriptions(user);
}
void ASRServicer::clear_buffers(const std::string& user) {
    if (client_buffers.count(user)) {
        client_buffers.erase(user);
    }
}
void ASRServicer::clear_transcriptions(const std::string& user) {
    if (client_transcription.count(user)) {
        client_transcription.erase(user);
    }
}
void ASRServicer::disconnect(const std::string& user) {
    clear_states(user);
    std::cout << "Disconnecting user: " << user << std::endl;
}
grpc::Status ASRServicer::Recognize(
@@ -62,10 +38,20 @@
    grpc::ServerReaderWriter<Response, Request>* stream) {
    Request req;
    std::unordered_map<std::string, std::string> client_buffers;
    std::unordered_map<std::string, std::string> client_transcription;
    while (stream->Read(&req)) {
        if (req.isend()) {
            std::cout << "asr end" << std::endl;
            disconnect(req.user());
            // disconnect
            if (client_buffers.count(req.user())) {
                client_buffers.erase(req.user());
            }
            if (client_transcription.count(req.user())) {
                client_transcription.erase(req.user());
            }
            Response res;
            res.set_sentence(
                R"({"success": true, "detail": "asr end"})"
@@ -88,7 +74,7 @@
            res.set_language(req.language());
            stream->Write(res);
        } else if (!req.speaking()) {
            if (client_buffers.count(req.user()) == 0) {
            if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
                Response res;
                res.set_sentence(
                    R"({"success": true, "detail": "waiting_for_voice"})"
@@ -99,14 +85,24 @@
                stream->Write(res);
            }else {
                auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
                std::string tmp_data = this->client_buffers[req.user()];
                this->clear_states(req.user());
                if (req.audio_data().size() > 0) {
                  auto& buf = client_buffers[req.user()];
                  buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
                }
                std::string tmp_data = client_buffers[req.user()];
                // clear_states
                if (client_buffers.count(req.user())) {
                    client_buffers.erase(req.user());
                }
                if (client_transcription.count(req.user())) {
                    client_transcription.erase(req.user());
                }
                Response res;
                res.set_sentence(
                    R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
                );
        int data_len_int = tmp_data.length();
                int data_len_int = tmp_data.length();
                std::string data_len = std::to_string(data_len_int);
                std::stringstream ss;
                ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")"  << R"("})";
@@ -129,18 +125,15 @@
                    res.set_user(req.user());
                    res.set_action("finish");
                    res.set_language(req.language());
                    stream->Write(res);
                }
                else {
                    RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
                    std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg;
                    FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
                    std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
                    auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
                    std::string delay_str = std::to_string(end_time - begin_time);
                    std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
                    Response res;
                    std::stringstream ss;
@@ -150,8 +143,7 @@
                    res.set_user(req.user());
                    res.set_action("finish");
                    res.set_language(req.language());
                    stream->Write(res);
                }
            }
@@ -165,10 +157,9 @@
            res.set_language(req.language());
            stream->Write(res);
        }
    }
    }
    return Status::OK;
}
void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
    std::string server_address;
funasr/runtime/grpc/paraformer_server.h
@@ -15,7 +15,7 @@
#include <chrono>
#include "paraformer.grpc.pb.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"
using grpc::Server;
@@ -35,22 +35,16 @@
{
    std::string msg;
    float  snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;
class ASRServicer final : public ASR::Service {
  private:
    int init_flag;
    std::unordered_map<std::string, std::string> client_buffers;
    std::unordered_map<std::string, std::string> client_transcription;
  public:
    ASRServicer(const char* model_path, int thread_num, bool quantize);
    void clear_states(const std::string& user);
    void clear_buffers(const std::string& user);
    void clear_transcriptions(const std::string& user);
    void disconnect(const std::string& user);
    grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
    RPASR_HANDLE AsrHanlde;
    FUNASR_HANDLE AsrHanlde;
    
};
funasr/runtime/onnxruntime/CMakeLists.txt
@@ -2,24 +2,27 @@
project(FunASRonnx)
set(CMAKE_CXX_STANDARD 11)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
if(BIG_ENDIAN)
    message("Big endian system")
else()
    message("Little endian system")
endif()
# for onnxruntime
IF(WIN32)
    if(CMAKE_CL_64)
        link_directories(${ONNXRUNTIME_DIR}\\lib)
    else()
        add_definitions(-D_WIN_X86)
    endif()
ELSE()
link_directories(${ONNXRUNTIME_DIR}/lib)
    link_directories(${ONNXRUNTIME_DIR}/lib)
endif()
add_subdirectory("./third_party/yaml-cpp")
funasr/runtime/onnxruntime/include/Audio.h
@@ -6,6 +6,13 @@
#include <queue>
#include <stdint.h>
#ifndef model_sample_rate
#define model_sample_rate 16000
#endif
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
#endif
using namespace std;
class AudioFrame {
@@ -32,7 +39,6 @@
    int16_t *speech_buff;
    int speech_len;
    int speech_align_len;
    int16_t sample_rate;
    int offset;
    float align_size;
    int data_type;
@@ -43,10 +49,11 @@
    Audio(int data_type, int size);
    ~Audio();
    void disp();
    bool loadwav(const char* filename);
    bool loadwav(const char* buf, int nLen);
    bool loadpcmwav(const char* buf, int nFileLen);
    bool loadpcmwav(const char* filename);
    bool loadwav(const char* filename, int32_t* sampling_rate);
    void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
    bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
    bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
    bool loadpcmwav(const char* filename, int32_t* sampling_rate);
    int fetch_chunck(float *&dout, int len);
    int fetch(float *&dout, int &len, int &flag);
    void padding();
funasr/runtime/onnxruntime/include/libfunasrapi.h
New file
@@ -0,0 +1,77 @@
#pragma once
#ifdef WIN32
#ifdef _FUNASR_API_EXPORT
#define  _FUNASRAPI __declspec(dllexport)
#else
#define  _FUNASRAPI __declspec(dllimport)
#endif
#else
#define _FUNASRAPI
#endif
#ifndef _WIN32
#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else
#define FUNASR_CALLBCK_PREFIX __stdcall
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef void* FUNASR_HANDLE;
typedef void* FUNASR_RESULT;
typedef unsigned char FUNASR_BOOL;
#define FUNASR_TRUE 1
#define FUNASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM  4
// Decoding-mode selector accepted by the FunASRRecog* entry points.
// NOTE(review): the implementations in libfunasrapi.cpp never read their
// `Mode` argument, so the value currently appears to have no effect; the grpc
// server passes RASR_NONE.
// The enumerator spellings ("RPEFIX", "ATTENSION") are part of the public
// header and are kept as-is.
typedef enum
{
 RASR_NONE=-1,
 RASRM_CTC_GREEDY_SEARCH=0,
 RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
 RASRM_ATTENSION_RESCORING = 2,
}FUNASR_MODE;
// Identifiers for model families.
// NOTE(review): no function declared in this header takes a
// FUNASR_MODEL_TYPE; presumably kept for compatibility -- confirm before use.
typedef enum {
    FUNASR_MODEL_PADDLE = 0,
    FUNASR_MODEL_PADDLE_2 = 1,
    FUNASR_MODEL_K2 = 2,
    FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
// Progress callback invoked after each decoded segment.
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
// Public C API of the FunASR runtime.
// Creates a recognizer from the model directory; the returned handle must be
// released with FunASRUninit.
_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize);
// Recognition entry points. Each returns NULL when `handle` is NULL or the
// audio cannot be loaded; otherwise it returns a result that the caller must
// release with FunASRFreeResult. Pass NULL as fnCallback when no progress
// notification is wanted.
// FunASRRecogBuffer: szBuf/nLen hold an in-memory WAV file (header + samples).
_FUNASRAPI FUNASR_RESULT    FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
// FunASRRecogPCMBuffer: szBuf/nLen hold headerless 16-bit samples recorded at
// `sampling_rate`; the loader resamples when that differs from the model rate.
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
// FunASRRecogPCMFile: same as FunASRRecogPCMBuffer, reading the samples from a file.
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
// FunASRRecogFile: reads a WAV file; the sampling rate is taken from its header.
_FUNASRAPI FUNASR_RESULT    FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
// Returns the recognized text (NULL for a NULL result). The pointer is owned
// by the result object; nIndex is not used by the current implementation.
_FUNASRAPI const char*    FunASRGetResult(FUNASR_RESULT Result,int nIndex);
// Returns 0 for a NULL result and 1 otherwise.
_FUNASRAPI const int        FunASRGetRetNumber(FUNASR_RESULT Result);
// Releases a result returned by one of the FunASRRecog* functions.
_FUNASRAPI void            FunASRFreeResult(FUNASR_RESULT Result);
// Destroys a recognizer created by FunASRInit.
_FUNASRAPI void            FunASRUninit(FUNASR_HANDLE Handle);
// Returns the audio duration stored in the result (0 for a NULL result).
_FUNASRAPI const float    FunASRGetRetSnippetTime(FUNASR_RESULT Result);
#ifdef __cplusplus
}
#endif
funasr/runtime/onnxruntime/include/librapidasrapi.h
File was deleted
funasr/runtime/onnxruntime/readme.md
@@ -1,83 +1,70 @@
## 快速使用
### Windows
 安装Vs2022 打开cpp_onnx目录下的cmake工程,直接 build即可。 本仓库已经准备好所有相关依赖库。
 Windows下已经预置fftw3及onnxruntime库
### Linux
See the bottom of this page: Building Guidance
###  运行程序
tester  /path/to/models_dir /path/to/wave_file quantize(true or false)
例如: tester /data/models  /data/test.wav false
/data/models 需要包括如下三个文件: config.yaml, am.mvn, model.onnx(or model_quant.onnx)
## 支持平台
- Windows
- Linux/Unix
## 依赖
- fftw3
- openblas
- onnxruntime
## 导出onnx格式模型文件
安装 modelscope与FunASR,依赖:torch,torchaudio,安装过程[详细参考文档](https://github.com/alibaba-damo-academy/FunASR/wiki)
## Demo
```shell
pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install --editable ./
tester /path/models_dir /path/wave_file quantize(true or false)
```
导出onnx模型,[详见](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export),参考示例,从modelscope中模型导出:
The structure of /path/models_dir
```
config.yaml, am.mvn, model.onnx(or model_quant.onnx)
```
## Steps
### Export onnx
#### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
```shell
pip3 install torch torchaudio
pip install -U modelscope
pip install -U funasr
```
#### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
## Building Guidance for Linux/Unix
### Building for Linux/Unix
```
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
mkdir build
cd build
#### Download onnxruntime
```shell
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
# ls
# onnxruntime-linux-x64-1.14.0  onnxruntime-linux-x64-1.14.0.tgz
```
#install fftw3-dev
ubuntu: apt install libfftw3-dev
centos: yum install fftw fftw-devel
#### Install fftw3
```shell
sudo apt install libfftw3-dev #ubuntu
# sudo yum install fftw fftw-devel #centos
```
#install openblas
bash ./third_party/install_openblas.sh
#### Install openblas
```shell
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
```
# build
 cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
 make
#### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
mkdir build && cd build
cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
make
```
 # then, in the `tester` subfolder of the current directory, you will find the program `tester`
````
### The structure of a qualified onnxruntime package.
#### The structure of a qualified onnxruntime package.
```
onnxruntime_xxx
├───include
└───lib
```
## 注意
本程序只支持 采样率16000hz, 位深16bit的 **单声道** 音频。
### Building for Windows
Refer to win/
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
funasr/runtime/onnxruntime/src/Audio.cpp
@@ -3,10 +3,95 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include <assert.h>
#include "Audio.h"
#include "precomp.h"
using namespace std;
// see http://soundfile.sapp.org/doc/WaveFormat/
// Note: We assume little endian here
// In-memory image of the canonical 44-byte RIFF/WAVE header. The fields are
// filled by copying raw bytes from a file or buffer, so their order and
// widths must not change.
struct WaveHeader {
  // Returns true when the header describes mono, 16-bit, uncompressed PCM in
  // a RIFF/WAVE container. On the first mismatching field the reason is
  // printed to stdout and false is returned.
  bool Validate() const {
    //                 F F I R
    if (chunk_id != 0x46464952) {
      printf("Expected chunk_id RIFF. Given: 0x%08x\n",
             static_cast<unsigned int>(chunk_id));
      return false;
    }
    //               E V A W
    if (format != 0x45564157) {
      printf("Expected format WAVE. Given: 0x%08x\n",
             static_cast<unsigned int>(format));
      return false;
    }

    // "fmt " tag
    if (subchunk1_id != 0x20746d66) {
      printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
             static_cast<unsigned int>(subchunk1_id));
      return false;
    }

    if (subchunk1_size != 16) {  // 16 for PCM
      printf("Expected subchunk1_size 16. Given: %d\n",
                       subchunk1_size);
      return false;
    }

    if (audio_format != 1) {  // 1 for PCM
      printf("Expected audio_format 1. Given: %d\n", audio_format);
      return false;
    }

    if (num_channels != 1) {  // we support only single channel for now
      printf("Expected single channel. Given: %d\n", num_channels);
      return false;
    }

    // The header comes from untrusted input: widen to 64 bits so that a large
    // sample_rate times bits_per_sample cannot overflow the signed product.
    const int64_t expected_byte_rate =
        static_cast<int64_t>(sample_rate) * num_channels * bits_per_sample / 8;
    if (byte_rate != expected_byte_rate) {
      printf("Expected byte_rate %lld. Given: %d\n",
             static_cast<long long>(expected_byte_rate), byte_rate);
      return false;
    }

    const int32_t expected_block_align = num_channels * bits_per_sample / 8;
    if (block_align != expected_block_align) {
      printf("Expected block_align %d. Given: %d\n", expected_block_align,
             block_align);
      return false;
    }

    if (bits_per_sample != 16) {  // we support only 16 bits per sample
      printf("Expected bits_per_sample 16. Given: %d\n",
                       bits_per_sample);
      return false;
    }

    return true;
  }

  // Skips chunks until subchunk2_id is "data", updating subchunk2_id and
  // subchunk2_size from the stream as it goes. On return either the stream
  // has failed or the stream is positioned at the first byte of the samples.
  // See https://en.wikipedia.org/wiki/WAV#Metadata and
  // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
  void SeekToDataChunk(std::istream &is) {
    //                              a t a d
    while (is && subchunk2_id != 0x61746164) {
      // A negative size would seek backwards and could revisit the same
      // chunk forever; treat it as a corrupt file instead.
      if (subchunk2_size < 0) {
        is.setstate(std::istream::failbit);
        break;
      }
      is.seekg(subchunk2_size, std::istream::cur);
      is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
      is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
    }
  }

  int32_t chunk_id;
  int32_t chunk_size;
  int32_t format;
  int32_t subchunk1_id;
  int32_t subchunk1_size;
  int16_t audio_format;
  int16_t num_channels;
  int32_t sample_rate;
  int32_t byte_rate;
  int16_t block_align;
  int16_t bits_per_sample;
  int32_t subchunk2_id;    // a tag of this chunk
  int32_t subchunk2_size;  // size of subchunk2
};
static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, "");
class AudioWindow {
  private:
@@ -56,7 +141,7 @@
    float frame_length = 400;
    float frame_shift = 160;
    float num_new_samples =
        ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length;
        ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
    end = start + num_new_samples;
    len = (int)num_new_samples;
@@ -111,62 +196,95 @@
void Audio::disp()
{
    printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000,
    printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate,
           speech_len);
}
float Audio::get_time_len()
{
    return (float)speech_len / 16000;
       //speech_len);
    return (float)speech_len / model_sample_rate;
}
bool Audio::loadwav(const char *filename)
void Audio::wavResample(int32_t sampling_rate, const float *waveform,
                          int32_t n)
{
    printf(
          "Creating a resampler:\n"
          "   in_sample_rate: %d\n"
          "   output_sample_rate: %d\n",
          sampling_rate, static_cast<int32_t>(model_sample_rate));
    float min_freq =
        std::min<int32_t>(sampling_rate, model_sample_rate);
    float lowpass_cutoff = 0.99 * 0.5 * min_freq;
    int32_t lowpass_filter_width = 6;
    //FIXME
    //auto resampler = new LinearResample(
    //      sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
    auto resampler = std::make_unique<LinearResample>(
          sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
    std::vector<float> samples;
    resampler->Resample(waveform, n, true, &samples);
    //reset speech_data
    speech_len = samples.size();
    if (speech_data != NULL) {
        free(speech_data);
    }
    speech_data = (float*)malloc(sizeof(float) * speech_len);
    memset(speech_data, 0, sizeof(float) * speech_len);
    copy(samples.begin(), samples.end(), speech_data);
}
bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
{
    WaveHeader header;
    if (speech_data != NULL) {
        free(speech_data);
    }
    if (speech_buff != NULL) {
        free(speech_buff);
    }
    offset = 0;
    FILE *fp;
    fp = fopen(filename, "rb");
    if (fp == nullptr)
    std::ifstream is(filename, std::ifstream::binary);
    is.read(reinterpret_cast<char *>(&header), sizeof(header));
    if(!is){
        fprintf(stderr, "Failed to read %s\n", filename);
        return false;
    fseek(fp, 0, SEEK_END);  /*定位到文件末尾*/
    uint32_t nFileLen = ftell(fp);  /*得到文件大小*/
    fseek(fp, 44, SEEK_SET);  /*跳过wav文件头*/
    speech_len = (nFileLen - 44) / 2;
    speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
    speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len);
    }
    *sampling_rate = header.sample_rate;
    // header.subchunk2_size contains the number of bytes in the data.
    // As we assume each sample contains two bytes, so it is divided by 2 here
    speech_len = header.subchunk2_size / 2;
    speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
    if (speech_buff)
    {
        memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
        int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
        fclose(fp);
        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
        is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size);
        if (!is) {
            fprintf(stderr, "Failed to read %s\n", filename);
            return false;
        }
        speech_data = (float*)malloc(sizeof(float) * speech_len);
        memset(speech_data, 0, sizeof(float) * speech_len);
        speech_data = (float*)malloc(sizeof(float) * speech_align_len);
        memset(speech_data, 0, sizeof(float) * speech_align_len);
        int i;
        float scale = 1;
        if (data_type == 1) {
            scale = 32768;
        }
        for (i = 0; i < speech_len; i++) {
        for (int32_t i = 0; i != speech_len; ++i) {
            speech_data[i] = (float)speech_buff[i] / scale;
        }
        //resample
        if(*sampling_rate != model_sample_rate){
            wavResample(*sampling_rate, speech_data, speech_len);
        }
        AudioFrame* frame = new AudioFrame(speech_len);
        frame_queue.push(frame);
        return true;
    }
@@ -174,57 +292,54 @@
        return false;
}
bool Audio::loadwav(const char* buf, int nFileLen)
bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
{
    WaveHeader header;
    if (speech_data != NULL) {
        free(speech_data);
    }
    if (speech_buff != NULL) {
        free(speech_buff);
    }
    offset = 0;
    size_t nOffset = 0;
    std::memcpy(&header, buf, sizeof(header));
#define WAV_HEADER_SIZE 44
    speech_len = (nFileLen - WAV_HEADER_SIZE) / 2;
    speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
    *sampling_rate = header.sample_rate;
    speech_len = header.subchunk2_size / 2;
    speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
    if (speech_buff)
    {
        memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
        memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t));
        speech_data = (float*)malloc(sizeof(float) * speech_len);
        memset(speech_data, 0, sizeof(float) * speech_len);
        speech_data = (float*)malloc(sizeof(float) * speech_align_len);
        memset(speech_data, 0, sizeof(float) * speech_align_len);
        int i;
        float scale = 1;
        if (data_type == 1) {
            scale = 32768;
        }
        for (i = 0; i < speech_len; i++) {
        for (int32_t i = 0; i != speech_len; ++i) {
            speech_data[i] = (float)speech_buff[i] / scale;
        }
        //resample
        if(*sampling_rate != model_sample_rate){
            wavResample(*sampling_rate, speech_data, speech_len);
        }
        AudioFrame* frame = new AudioFrame(speech_len);
        frame_queue.push(frame);
        return true;
    }
    else
        return false;
}
bool Audio::loadpcmwav(const char* buf, int nBufLen)
bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
{
    if (speech_data != NULL) {
        free(speech_data);
@@ -234,32 +349,28 @@
    }
    offset = 0;
    size_t nOffset = 0;
    speech_len = nBufLen / 2;
    speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
    if (speech_buff)
    {
        memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
        memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
        speech_data = (float*)malloc(sizeof(float) * speech_len);
        memset(speech_data, 0, sizeof(float) * speech_len);
        speech_data = (float*)malloc(sizeof(float) * speech_align_len);
        memset(speech_data, 0, sizeof(float) * speech_align_len);
        int i;
        float scale = 1;
        if (data_type == 1) {
            scale = 32768;
        }
        for (i = 0; i < speech_len; i++) {
        for (int32_t i = 0; i != speech_len; ++i) {
            speech_data[i] = (float)speech_buff[i] / scale;
        }
        //resample
        if(*sampling_rate != model_sample_rate){
            wavResample(*sampling_rate, speech_data, speech_len);
        }
        AudioFrame* frame = new AudioFrame(speech_len);
@@ -269,13 +380,10 @@
    }
    else
        return false;
}
bool Audio::loadpcmwav(const char* filename)
bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
{
    if (speech_data != NULL) {
        free(speech_data);
    }
@@ -293,34 +401,31 @@
    fseek(fp, 0, SEEK_SET);
    speech_len = (nFileLen) / 2;
    speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
    if (speech_buff)
    {
        memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
        int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
        fclose(fp);
        speech_data = (float*)malloc(sizeof(float) * speech_align_len);
        memset(speech_data, 0, sizeof(float) * speech_align_len);
        speech_data = (float*)malloc(sizeof(float) * speech_len);
        memset(speech_data, 0, sizeof(float) * speech_len);
        int i;
        float scale = 1;
        if (data_type == 1) {
            scale = 32768;
        }
        for (i = 0; i < speech_len; i++) {
        for (int32_t i = 0; i != speech_len; ++i) {
            speech_data[i] = (float)speech_buff[i] / scale;
        }
        //resample
        if(*sampling_rate != model_sample_rate){
            wavResample(*sampling_rate, speech_data, speech_len);
        }
        AudioFrame* frame = new AudioFrame(speech_len);
        frame_queue.push(frame);
    
        return true;
    }
@@ -328,7 +433,6 @@
        return false;
}
int Audio::fetch_chunck(float *&dout, int len)
{
funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,43 +1,44 @@
file(GLOB files1 "*.cpp")
file(GLOB files2 "*.cc")
file(GLOB files4 "paraformer/*.cpp")
set(files ${files1} ${files2} ${files3} ${files4})
# message("${files}")
add_library(rapidasr ${files})
add_library(funasr ${files})
if(WIN32)
        set(EXTRA_LIBS libfftw3f-3 yaml-cpp)
        if(CMAKE_CL_64)
            target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
            target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
        else()
            target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
            target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
        endif()
        target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
        target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
        
        target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT)
        target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
else()
    set(EXTRA_LIBS fftw3f pthread yaml-cpp)
    target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include")
    target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib")
    target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include")
    target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib")
    target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include")
    target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib")
    target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include")
    target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib")
    target_include_directories(rapidasr PUBLIC "/usr/include")
    target_link_directories(rapidasr PUBLIC "/usr/lib64")
    target_include_directories(funasr PUBLIC "/usr/include")
    target_link_directories(funasr PUBLIC "/usr/lib64")
    target_include_directories(rapidasr PUBLIC  ${FFTW3F_INCLUDE_DIR})
    target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR})
    target_include_directories(funasr PUBLIC  ${FFTW3F_INCLUDE_DIR})
    target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR})
    include_directories(${ONNXRUNTIME_DIR}/include)    
endif()
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
funasr/runtime/onnxruntime/src/FeatureExtract.cpp
@@ -5,14 +5,10 @@
FeatureExtract::FeatureExtract(int mode) : mode(mode)
{
    fftw_init();
}
FeatureExtract::~FeatureExtract()
{
    fftwf_free(fft_input);
    fftwf_free(fft_out);
    fftwf_destroy_plan(p);
}
void FeatureExtract::reset()
@@ -26,34 +22,25 @@
    return fqueue.size();
}
void FeatureExtract::fftw_init()
void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag)
{
    int fft_size = 512;
    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
    float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
    fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
    memset(fft_input, 0, sizeof(float) * fft_size);
    p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
}
void FeatureExtract::insert(float *din, int len, int flag)
{
    const float *window = (const float *)&window_hex;
    if (mode == 3)
        window = (const float *)&window_hamm_hex;
    int window_size = 400;
    int fft_size = 512;
    int window_shift = 160;
    speech.load(din, len);
    int i, j;
    float tmp_feature[80];
    if (mode == 0 || mode == 2 || mode == 3) {
        int ll = (speech.size() - 400) / 160 + 1;
        int ll = (speech.size() - window_size) / window_shift + 1;
        fqueue.reinit(ll);
    }
    for (i = 0; i <= speech.size() - 400; i = i + window_shift) {
    for (i = 0; i <= speech.size() - window_size; i = i + window_shift) {
        float tmp_mean = 0;
        for (j = 0; j < window_size; j++) {
            tmp_mean += speech[i + j];
@@ -70,7 +57,7 @@
            pre_val = cur_val;
        }
        fftwf_execute(p);
        fftwf_execute_dft_r2c(plan, fft_input, fft_out);
        melspect((float *)fft_out, tmp_feature);
        int tmp_flag = S_MIDDLE;
@@ -80,6 +67,8 @@
        fqueue.push(tmp_feature, tmp_flag);
    }
    speech.update(i);
    fftwf_free(fft_input);
    fftwf_free(fft_out);
}
bool FeatureExtract::fetch(Tensor<float> *&dout)
@@ -128,7 +117,6 @@
void FeatureExtract::melspect(float *din, float *dout)
{
    float fftmag[256];
//    float tmp;
    const float *melcoe = (const float *)melcoe_hex;
    int i;
    for (i = 0; i < 256; i++) {
funasr/runtime/onnxruntime/src/FeatureExtract.h
@@ -14,12 +14,11 @@
    SpeechWrap speech;
    FeatureQueue fqueue;
    int mode;
    int fft_size = 512;
    int window_size = 400;
    int window_shift = 160;
    float *fft_input;
    fftwf_complex *fft_out;
    fftwf_plan p;
    void fftw_init();
    //void fftw_init();
    void melspect(float *din, float *dout);
    void global_cmvn(float *din);
@@ -27,9 +26,9 @@
    FeatureExtract(int mode);
    ~FeatureExtract();
    int size();
    int status();
    //int status();
    void reset();
    void insert(float *din, int len, int flag);
    void insert(fftwf_plan plan, float *din, int len, int flag);
    bool fetch(Tensor<float> *&dout);
};
funasr/runtime/onnxruntime/src/Vocab.cpp
@@ -13,21 +13,6 @@
{
    ifstream in(filename);
    loadVocabFromYaml(filename);
    /*
    string line;
    if (in) // 有该文件
    {
        while (getline(in, line)) // line中不包括每行的换行符
        {
            vocab.push_back(line);
        }
    }
    else{
        printf("Cannot load vocab from: %s, there must be file vocab.txt", filename);
        exit(-1);
    }
    */
}
Vocab::~Vocab()
{
funasr/runtime/onnxruntime/src/commonfunc.h
@@ -5,7 +5,7 @@
{
    std::string msg;
    float  snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;
#ifdef _WIN32
@@ -53,4 +53,4 @@
        }
    }
}
}
funasr/runtime/onnxruntime/src/libfunasrapi.cpp
New file
@@ -0,0 +1,184 @@
#include "precomp.h"
#ifdef __cplusplus
extern "C" {
#endif
    // APIs for qmasr
    _FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
    {
        Model* mm = create_model(szModelDir, nThreadNum, quantize);
        return mm;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
            return nullptr;
        int32_t sampling_rate = -1;
        Audio audio(1);
        if (!audio.loadwav(szBuf, nLen, &sampling_rate))
            return nullptr;
        //audio.split();
        float* buff;
        int len;
        int flag=0;
        FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
        pResult->snippet_time = audio.get_time_len();
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
            if (fnCallback)
                fnCallback(nStep, nTotal);
        }
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
            return nullptr;
        Audio audio(1);
        if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
            return nullptr;
        //audio.split();
        float* buff;
        int len;
        int flag = 0;
        FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
        pResult->snippet_time = audio.get_time_len();
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
            if (fnCallback)
                fnCallback(nStep, nTotal);
        }
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
            return nullptr;
        Audio audio(1);
        if (!audio.loadpcmwav(szFileName, &sampling_rate))
            return nullptr;
        //audio.split();
        float* buff;
        int len;
        int flag = 0;
        FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
        pResult->snippet_time = audio.get_time_len();
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
            if (fnCallback)
                fnCallback(nStep, nTotal);
        }
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
            return nullptr;
        int32_t sampling_rate = -1;
        Audio audio(1);
        if(!audio.loadwav(szWavfile, &sampling_rate))
            return nullptr;
        //audio.split();
        float* buff;
        int len;
        int flag = 0;
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
        pResult->snippet_time = audio.get_time_len();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg+= msg;
            nStep++;
            if (fnCallback)
                fnCallback(nStep, nTotal);
        }
        return pResult;
    }
    _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
    {
        if (!Result)
            return 0;
        return 1;
    }
    _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
    {
        if (!Result)
            return 0.0f;
        return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
    }
    _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
    {
        FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
        if(!pResult)
            return nullptr;
        return pResult->msg.c_str();
    }
    _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
    {
        if (Result)
        {
            delete (FUNASR_RECOG_RESULT*)Result;
        }
    }
    _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
            return;
        delete pRecogObj;
    }
#ifdef __cplusplus
}
#endif
funasr/runtime/onnxruntime/src/librapidasrapi.cpp
File was deleted
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
    string model_path;
    string cmvn_path;
    string config_path;
@@ -18,7 +18,10 @@
    cmvn_path = pathAppend(path, "am.mvn");
    config_path = pathAppend(path, "config.yaml");
    fe = new FeatureExtract(3);
    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
    memset(fft_input, 0, sizeof(float) * fft_size);
    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
    //sessionOptions.SetInterOpNumThreads(1);
    sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -26,20 +29,20 @@
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
    string strName;
    getInputName(m_session, strName);
    getInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    getInputName(m_session, strName,1);
    getInputName(m_session.get(), strName,1);
    m_strInputNames.push_back(strName);
    
    getOutputName(m_session, strName);
    getOutputName(m_session.get(), strName);
    m_strOutputNames.push_back(strName);
    getOutputName(m_session, strName,1);
    getOutputName(m_session.get(), strName,1);
    m_strOutputNames.push_back(strName);
    for (auto& item : m_strInputNames)
@@ -52,20 +55,16 @@
ModelImp::~ModelImp()
{
    if(fe)
        delete fe;
    if (m_session)
    {
        delete m_session;
        m_session = nullptr;
    }
    if(vocab)
        delete vocab;
    fftwf_free(fft_input);
    fftwf_free(fft_out);
    fftwf_destroy_plan(plan);
    fftwf_cleanup();
}
void ModelImp::reset()
{
    fe->reset();
}
void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -159,13 +158,20 @@
string ModelImp::forward(float* din, int len, int flag)
{
    Tensor<float>* in;
    fe->insert(din, len, flag);
    FeatureExtract* fe = new FeatureExtract(3);
    fe->reset();
    fe->insert(plan, din, len, flag);
    fe->fetch(in);
    apply_lfr(in);
    apply_cmvn(in);
    Ort::RunOptions run_option;
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
    Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
@@ -192,7 +198,6 @@
        auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -203,9 +208,14 @@
        result = "";
    }
    if(in)
    if(in){
        delete in;
        in = nullptr;
    }
    if(fe){
        delete fe;
        fe = nullptr;
    }
    return result;
}
funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -8,7 +8,10 @@
    class ModelImp : public Model {
    private:
        FeatureExtract* fe;
        int fft_size=512;
        float *fft_input;
        fftwf_complex *fft_out;
        fftwf_plan plan;
        Vocab* vocab;
        vector<float> means_list;
@@ -21,21 +24,13 @@
        string greedy_search( float* in, int nLen);
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
        Ort::Session* m_session = nullptr;
        Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
        Ort::SessionOptions sessionOptions = Ort::SessionOptions();
        std::unique_ptr<Ort::Session> m_session;
        Ort::Env env_;
        Ort::SessionOptions sessionOptions;
        vector<string> m_strInputNames, m_strOutputNames;
        vector<const char*> m_szInputNames;
        vector<const char*> m_szOutputNames;
        //string m_strInputName, m_strInputNameLen;
        //string m_strOutputName, m_strOutputNameLen;
    public:
        ModelImp(const char* path, int nNumThread=0, bool quantize=false);
funasr/runtime/onnxruntime/src/precomp.h
@@ -44,9 +44,10 @@
#include "FeatureQueue.h"
#include "SpeechWrap.h"
#include <Audio.h>
#include "resample.h"
#include "Model.h"
#include "paraformer_onnx.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"
using namespace paraformer;
funasr/runtime/onnxruntime/src/resample.cc
New file
@@ -0,0 +1,305 @@
/**
 * Copyright     2013  Pegah Ghahremani
 *               2014  IMSL, PKU-HKUST (author: Wei Shi)
 *               2014  Yanqing Sun, Junjie Wang
 *               2014  Johns Hopkins University (author: Daniel Povey)
 * Copyright     2023  Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// this file is copied and modified from
// kaldi/src/feat/resample.cc
#include "resample.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <cstdlib>
#include <type_traits>
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
/// Greatest common divisor of two integers, always returned as a
/// non-negative value.  If exactly one argument is zero the absolute value
/// of the other is returned; if both are zero the process prints an error
/// and exits with status -1.
template <class I>
I Gcd(I m, I n) {
  // this function is adapted from kaldi/src/base/kaldi-math.h
  static_assert(std::is_integral<I>::value, "");
  if (m == 0 && n == 0) {  // every integer divides zero: no greatest one.
    fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
    exit(-1);
  }
  if (m == 0) return n > 0 ? n : -n;
  if (n == 0) return m > 0 ? m : -m;
  // Euclid's algorithm, reducing each operand against the other in turn.
  for (;;) {
    m %= n;
    if (m == 0) return n > 0 ? n : -n;
    n %= m;
    if (n == 0) return m > 0 ? m : -m;
  }
}
/// Least common multiple of two integers.  Both inputs must be positive;
/// this is enforced with an assert.
template <class I>
I Lcm(I m, I n) {
  // This function is adapted from kaldi/src/base/kaldi-math.h
  assert(m > 0 && n > 0);
  const I divisor = Gcd(m, n);
  // divisor * (m / divisor) is exactly m, so this is m * n / gcd without
  // ever forming the full product m * n.
  return m * (n / divisor);
}
// Sum of a[k] * b[k] for k in [0, n), accumulated left to right in float.
static float DotProduct(const float *a, const float *b, int32_t n) {
  float acc = 0;
  int32_t k = 0;
  while (k != n) {
    acc += a[k] * b[k];
    ++k;
  }
  return acc;
}
LinearResample::LinearResample(int32_t samp_rate_in_hz,
                               int32_t samp_rate_out_hz, float filter_cutoff_hz,
                               int32_t num_zeros)
    : samp_rate_in_(samp_rate_in_hz),
      samp_rate_out_(samp_rate_out_hz),
      filter_cutoff_(filter_cutoff_hz),
      num_zeros_(num_zeros) {
  assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 &&
         filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz &&
         filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0);
  // base_freq is the frequency of the repeating unit, which is the gcd
  // of the input frequencies.
  int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_);
  input_samples_in_unit_ = samp_rate_in_ / base_freq;
  output_samples_in_unit_ = samp_rate_out_ / base_freq;
  SetIndexesAndWeights();
  Reset();
}
void LinearResample::SetIndexesAndWeights() {
  first_index_.resize(output_samples_in_unit_);
  weights_.resize(output_samples_in_unit_);
  double window_width = num_zeros_ / (2.0 * filter_cutoff_);
  for (int32_t i = 0; i < output_samples_in_unit_; i++) {
    double output_t = i / static_cast<double>(samp_rate_out_);
    double min_t = output_t - window_width, max_t = output_t + window_width;
    // we do ceil on the min and floor on the max, because if we did it
    // the other way around we would unnecessarily include indexes just
    // outside the window, with zero coefficients.  It's possible
    // if the arguments to the ceil and floor expressions are integers
    // (e.g. if filter_cutoff_ has an exact ratio with the sample rates),
    // that we unnecessarily include something with a zero coefficient,
    // but this is only a slight efficiency issue.
    int32_t min_input_index = ceil(min_t * samp_rate_in_),
            max_input_index = floor(max_t * samp_rate_in_),
            num_indices = max_input_index - min_input_index + 1;
    first_index_[i] = min_input_index;
    weights_[i].resize(num_indices);
    for (int32_t j = 0; j < num_indices; j++) {
      int32_t input_index = min_input_index + j;
      double input_t = input_index / static_cast<double>(samp_rate_in_),
             delta_t = input_t - output_t;
      // sign of delta_t doesn't matter.
      weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_;
    }
  }
}
/** Here, t is a time in seconds representing an offset from
    the center of the windowed filter function, and FilterFunction(t)
    returns the windowed filter function, described
    in the header as h(t) = f(t)g(t), evaluated at t.
*/
float LinearResample::FilterFunc(float t) const {
  float window,  // raised-cosine (Hanning) window of width
                 // num_zeros_/2*filter_cutoff_
      filter;    // sinc filter function
  if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
    window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
  else
    window = 0.0;  // outside support of window function
  if (t != 0)
    filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t);
  else
    filter = 2 * filter_cutoff_;  // limit of the function at t = 0
  return filter * window;
}
void LinearResample::Reset() {
  input_sample_offset_ = 0;
  output_sample_offset_ = 0;
  input_remainder_.resize(0);
}
void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
                              std::vector<float> *output) {
  int64_t tot_input_samp = input_sample_offset_ + input_dim,
          tot_output_samp = GetNumOutputSamples(tot_input_samp, flush);
  assert(tot_output_samp >= output_sample_offset_);
  output->resize(tot_output_samp - output_sample_offset_);
  // samp_out is the index into the total output signal, not just the part
  // of it we are producing here.
  for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
       samp_out++) {
    int64_t first_samp_in;
    int32_t samp_out_wrapped;
    GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
    const std::vector<float> &weights = weights_[samp_out_wrapped];
    // first_input_index is the first index into "input" that we have a weight
    // for.
    int32_t first_input_index =
        static_cast<int32_t>(first_samp_in - input_sample_offset_);
    float this_output;
    if (first_input_index >= 0 &&
        first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
      this_output =
          DotProduct(input + first_input_index, weights.data(), weights.size());
    } else {  // Handle edge cases.
      this_output = 0.0;
      for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) {
        float weight = weights[i];
        int32_t input_index = first_input_index + i;
        if (input_index < 0 &&
            static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) {
          this_output +=
              weight * input_remainder_[input_remainder_.size() + input_index];
        } else if (input_index >= 0 && input_index < input_dim) {
          this_output += weight * input[input_index];
        } else if (input_index >= input_dim) {
          // We're past the end of the input and are adding zero; should only
          // happen if the user specified flush == true, or else we would not
          // be trying to output this sample.
          assert(flush);
        }
      }
    }
    int32_t output_index =
        static_cast<int32_t>(samp_out - output_sample_offset_);
    (*output)[output_index] = this_output;
  }
  if (flush) {
    Reset();  // Reset the internal state.
  } else {
    SetRemainder(input, input_dim);
    input_sample_offset_ = tot_input_samp;
    output_sample_offset_ = tot_output_samp;
  }
}
int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
                                            bool flush) const {
  // For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
  // where tick_freq is the least common multiple of samp_rate_in_ and
  // samp_rate_out_.
  int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_);
  int32_t ticks_per_input_period = tick_freq / samp_rate_in_;
  // work out the number of ticks in the time interval
  // [ 0, input_num_samp/samp_rate_in_ ).
  int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
  if (!flush) {
    float window_width = num_zeros_ / (2.0 * filter_cutoff_);
    // To count the window-width in ticks we take the floor.  This
    // is because since we're looking for the largest integer num-out-samp
    // that fits in the interval, which is open on the right, a reduction
    // in interval length of less than a tick will never make a difference.
    // For example, the largest integer in the interval [ 0, 2 ) and the
    // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
    // So when we're subtracting the window-width we can ignore the fractional
    // part.
    int32_t window_width_ticks = floor(window_width * tick_freq);
    // The time-period of the output that we can sample gets reduced
    // by the window-width (which is actually the distance from the
    // center to the edge of the windowing function) if we're not
    // "flushing the output".
    interval_length_in_ticks -= window_width_ticks;
  }
  if (interval_length_in_ticks <= 0) return 0;
  int32_t ticks_per_output_period = tick_freq / samp_rate_out_;
  // Get the last output-sample in the closed interval, i.e. replacing [ ) with
  // [ ].  Note: integer division rounds down.  See
  // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
  // the notation.
  int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
  // We need the last output-sample in the open interval, so if it takes us to
  // the end of the interval exactly, subtract one.
  if (last_output_samp * ticks_per_output_period == interval_length_in_ticks)
    last_output_samp--;
  // First output-sample index is zero, so the number of output samples
  // is the last output-sample plus one.
  int64_t num_output_samp = last_output_samp + 1;
  return num_output_samp;
}
// inline
void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in,
                                int32_t *samp_out_wrapped) const {
  // A unit is the smallest nonzero amount of time that is an exact
  // multiple of the input and output sample periods.  The unit index
  // is the answer to "which numbered unit we are in".
  int64_t unit_index = samp_out / output_samples_in_unit_;
  // samp_out_wrapped is equal to samp_out % output_samples_in_unit_
  *samp_out_wrapped =
      static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_);
  *first_samp_in =
      first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_;
}
void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
  std::vector<float> old_remainder(input_remainder_);
  // max_remainder_needed is the width of the filter from side to side,
  // measured in input samples.  you might think it should be half that,
  // but you have to consider that you might be wanting to output samples
  // that are "in the past" relative to the beginning of the latest
  // input... anyway, storing more remainder than needed is not harmful.
  int32_t max_remainder_needed =
      ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
  input_remainder_.resize(max_remainder_needed);
  for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
       index < 0; index++) {
    // we interpret "index" as an offset from the end of "input" and
    // from the end of input_remainder_.
    int32_t input_index = index + input_dim;
    if (input_index >= 0) {
      input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
          input[input_index];
    } else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) {
      input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
          old_remainder[input_index +
                        static_cast<int32_t>(old_remainder.size())];
      // else leave it at zero.
    }
  }
}
funasr/runtime/onnxruntime/src/resample.h
New file
@@ -0,0 +1,137 @@
/**
 * Copyright     2013  Pegah Ghahremani
 *               2014  IMSL, PKU-HKUST (author: Wei Shi)
 *               2014  Yanqing Sun, Junjie Wang
 *               2014  Johns Hopkins University (author: Daniel Povey)
 * Copyright     2023  Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// this file is copied and modified from
// kaldi/src/feat/resample.h
#include <cstdint>
#include <vector>
/*
   We require that the input and output sampling rate be specified as
   integers, as this is an easy way to specify that their ratio be rational.
*/
class LinearResample {
  // Resamples a signal from one integer sampling rate to another with a
  // windowed-sinc filter, either in one call or piece by piece.
 public:
  /// Constructor.  We make the input and output sample rates integers, because
  /// we are going to need to find a common divisor.  This should just remind
  /// you that they need to be integers.  The filter cutoff must be positive
  /// and at most samp_rate_in_hz/2 and samp_rate_out_hz/2 (checked with an
  /// assert).  num_zeros controls the sharpness of the filter: more means
  /// sharper but less efficient.  We suggest around 4 to 10 for normal use.
  LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz,
                 float filter_cutoff_hz, int32_t num_zeros);
  /// Calling the function Reset() resets the state of the object prior to
  /// processing a new signal; it is only necessary if you have called
  /// Resample(x, x_size, false, y) for some signal, leading to a remainder of
  /// the signal being stored, but then abandon processing the signal before
  /// calling Resample(x, x_size, true, y) for the last piece.  Calling it
  /// unnecessarily between signals will not do any harm.
  void Reset();
  /// This function does the resampling.  If you call it with flush == true and
  /// you have never called it with flush == false, it just resamples the input
  /// signal (it resizes the output to a suitable number of samples).
  ///
  /// You can also use this function to process a signal a piece at a time.
  /// Suppose you break it into piece1, piece2, ... pieceN.  You can call
  /// \code{.cc}
  /// Resample(piece1, piece1_size, false, &output1);
  /// Resample(piece2, piece2_size, false, &output2);
  /// Resample(piece3, piece3_size, true, &output3);
  /// \endcode
  /// If you call it with flush == false, it won't output the last few samples
  /// but will remember them, so that if you later give it a second piece of
  /// the input signal it can process it correctly.
  /// If your most recent call to the object was with flush == false, it will
  /// have internal state; you can remove this by calling Reset().
  /// Empty input is acceptable.
  void Resample(const float *input, int32_t input_dim, bool flush,
                std::vector<float> *output);
  /// Return the input and output sampling rates (for checks, for example).
  int32_t GetInputSamplingRate() const { return samp_rate_in_; }
  int32_t GetOutputSamplingRate() const { return samp_rate_out_; }
 private:
  /// Fills first_index_ and weights_ for one repeating unit of output samples.
  void SetIndexesAndWeights();
  /// Evaluates the windowed filter function at the given time offset (in
  /// seconds) from the centre of the filter.
  float FilterFunc(float) const;
  /// This function outputs the number of output samples we will output
  /// for a signal with "input_num_samp" input samples.  If flush == true,
  /// we return the largest n such that
  /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ),
  /// and note that the interval is half-open.  If flush == false,
  /// define window_width as num_zeros / (2.0 * filter_cutoff_);
  /// we return the largest n such that (n/samp_rate_out_) is in the interval
  /// [ 0, input_num_samp/samp_rate_in_ - window_width ).
  int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const;
  /// Given an output-sample index, this function outputs to *first_samp_in the
  /// first input-sample index that we have a weight on (may be negative),
  /// and to *samp_out_wrapped the index into weights_ where we can get the
  /// corresponding weights on the input.
  inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in,
                         int32_t *samp_out_wrapped) const;
  /// Stores the trailing part of the signal seen so far in input_remainder_
  /// so that a later call to Resample() can continue the same signal.
  void SetRemainder(const float *input, int32_t input_dim);
 private:
  // The following variables are provided by the user.
  int32_t samp_rate_in_;
  int32_t samp_rate_out_;
  float filter_cutoff_;
  int32_t num_zeros_;
  int32_t input_samples_in_unit_;  ///< The number of input samples in the
                                   ///< smallest repeating unit:
                                   ///< samp_rate_in_hz / Gcd(samp_rate_in_hz,
                                   ///< samp_rate_out_hz)
  int32_t output_samples_in_unit_;  ///< The number of output samples in the
                                    ///< smallest repeating unit:
                                    ///< samp_rate_out_hz /
                                    ///< Gcd(samp_rate_in_hz, samp_rate_out_hz)
  /// The first input-sample index that we sum over, for this output-sample
  /// index.  May be negative; any truncation at the beginning is handled
  /// separately.  This is just for the first few output samples, but we can
  /// extrapolate the correct input-sample index for arbitrary output samples.
  std::vector<int32_t> first_index_;
  /// Weights on the input samples, for this output-sample index.
  std::vector<std::vector<float>> weights_;
  // the following variables keep track of where we are in a particular signal,
  // if it is being provided over multiple calls to Resample().
  int64_t input_sample_offset_;   ///< The number of input samples we have
                                  ///< already received for this signal
                                  ///< (including anything in
                                  ///< input_remainder_)
  int64_t output_sample_offset_;  ///< The number of samples we have already
                                  ///< output for this signal.
  std::vector<float> input_remainder_;  ///< A small trailing part of the
                                        ///< previously seen input signal.
};
funasr/runtime/onnxruntime/tester/CMakeLists.txt
@@ -8,7 +8,7 @@
    endif()
endif()
set(EXTRA_LIBS rapidasr)
set(EXTRA_LIBS funasr)
include_directories(${CMAKE_SOURCE_DIR}/include)
funasr/runtime/onnxruntime/tester/tester.cpp
@@ -5,7 +5,7 @@
#include <win_func.h>
#endif
#include "librapidasrapi.h"
#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@@ -26,7 +26,7 @@
    // is quantize
    bool quantize = false;
    istringstream(argv[3]) >> boolalpha >> quantize;
    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
    if (!AsrHanlde)
    {
@@ -42,62 +42,22 @@
    gettimeofday(&start, NULL);
    float snippet_time = 0.0f;
    RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
    gettimeofday(&end, NULL);
   
    if (Result)
    {
        string msg = RapidAsrGetResult(Result, 0);
        string msg = FunASRGetResult(Result, 0);
        setbuf(stdout, NULL);
        cout << "Result: \"";
        cout << msg << "\"." << endl;
        snippet_time = RapidAsrGetRetSnippetTime(Result);
        RapidAsrFreeResult(Result);
        printf("Result: %s \n", msg.c_str());
        snippet_time = FunASRGetRetSnippetTime(Result);
        FunASRFreeResult(Result);
    }
    else
    {
        cout <<"no return data!";
    }
    //char* buff = nullptr;
    //int len = 0;
    //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
    //if (ifs.is_open())
    //{
    //    ifs.seekg(0, std::ios::end);
    //    len = ifs.tellg();
    //    ifs.seekg(0, std::ios::beg);
    //    buff = new char[len];
    //    ifs.read(buff, len);
    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
    //    RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
    //    gettimeofday(&end, NULL);
    //
    //    if (Result)
    //    {
    //        string msg = RapidAsrGetResult(Result, 0);
    //        setbuf(stdout, NULL);
    //        cout << "Result: \"";
    //        cout << msg << endl;
    //        cout << "\"." << endl;
    //        snippet_time = RapidAsrGetRetSnippetTime(Result);
    //        RapidAsrFreeResult(Result);
    //    }
    //    else
    //    {
    //        cout <<"no return data!";
    //    }
    //
    //delete[]buff;
    //}
 
    printf("Audio length %lfs.\n", (double)snippet_time);
    seconds = (end.tv_sec - start.tv_sec);
@@ -105,9 +65,9 @@
    printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
    printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
    RapidAsrUninit(AsrHanlde);
    FunASRUninit(AsrHanlde);
    return 0;
}
funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -5,7 +5,7 @@
#include <win_func.h>
#endif
#include "librapidasrapi.h"
#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@@ -47,7 +47,7 @@
    bool quantize = false;
    istringstream(argv[3]) >> boolalpha >> quantize;
    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
    if (!AsrHanlde)
    {
        printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@@ -61,7 +61,7 @@
    // warm up
    for (size_t i = 0; i < 30; i++)
    {
        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
    }
    // forward
@@ -72,19 +72,19 @@
    for (size_t i = 0; i < wav_list.size(); i++)
    {
        gettimeofday(&start, NULL);
        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        total_time += taking_micros;
        if(Result){
            string msg = RapidAsrGetResult(Result, 0);
            printf("Result: %s \n", msg);
            string msg = FunASRGetResult(Result, 0);
            printf("Result: %s \n", msg.c_str());
            snippet_time = RapidAsrGetRetSnippetTime(Result);
            snippet_time = FunASRGetRetSnippetTime(Result);
            total_length += snippet_time;
            RapidAsrFreeResult(Result);
            FunASRFreeResult(Result);
        }else{
            cout <<"No return data!";
        }
@@ -94,6 +94,6 @@
    printf("total_time_comput %ld ms.\n", total_time / 1000);
    printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
    RapidAsrUninit(AsrHanlde);
    FunASRUninit(AsrHanlde);
    return 0;
}
funasr/runtime/python/grpc/grpc_main_client.py
New file
@@ -0,0 +1,62 @@
import grpc
import json
import time
import asyncio
import soundfile as sf
import argparse
from grpc_client import transcribe_audio_bytes
from paraformer_pb2_grpc import ASRStub
# send the audio data once
async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
    """Send every audio file listed in ``wav_scp`` to the ASR server once.

    Each non-blank entry of ``wav_scp`` is split on whitespace and its second
    field is read as the path of the audio file. The whole file is sent in a
    single request, and the recognized text together with the elapsed time is
    printed as a dict.

    Args:
        wav_scp: iterable of scp lines.
        grpc_uri: target passed to ``grpc.insecure_channel``.
        asr_user: user name forwarded with every request.
        language: language tag forwarded with every request.
    """
    with grpc.insecure_channel(grpc_uri) as channel:
        stub = ASRStub(channel)
        for line in wav_scp:
            if not line.strip():
                # A blank line (e.g. a trailing newline at the end of the scp
                # file) has no path field; skip it instead of crashing.
                continue
            wav_file = line.split()[1]
            wav, _ = sf.read(wav_file, dtype='int16')
            b = time.time()
            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
            resp = next(response)
            text = ''
            # The text is only taken when a 'decoding' message is followed by
            # a 'finish' message; otherwise it stays empty.
            if resp.action == 'decoding':
                resp = next(response)
                if resp.action == 'finish':
                    text = json.loads(resp.sentence)['text']
            # Send the isEnd=True request; its reply is not read. The call
            # object stays bound to `response`, as before.
            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
            res = {'text': text, 'time': time.time() - b}
            print(res)
async def test(args):
    """Read the scp list named by ``args.wav_scp`` and recognize every entry.

    Args:
        args: parsed command-line namespace; ``wav_scp``, ``host``, ``port``
            and ``user_allowed`` are read.
    """
    # Close the list file as soon as it has been read.
    with open(args.wav_scp, "r") as scp_file:
        wav_scp = scp_file.readlines()
    uri = '{}:{}'.format(args.host, args.port)
    await grpc_rec(wav_scp, uri, args.user_allowed, language='zh-CN')
if __name__ == '__main__':
    # Command-line options, registered in the order they appear in --help.
    cli_options = (
        ("--host", dict(type=str, default="127.0.0.1",
                        help="grpc server host ip")),
        ("--port", dict(type=int, default=10108,
                        help="grpc server port")),
        ("--user_allowed", dict(type=str, default="project1_user1",
                                help="allowed user for grpc client")),
        ("--sample_rate", dict(type=int, default=16000,
                               help="audio sample_rate from client")),
        ("--wav_scp", dict(type=str, required=True,
                           help="audio wav scp")),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    asyncio.run(test(args))
funasr/runtime/python/grpc/grpc_server.py
@@ -109,7 +109,7 @@
                            else:
                                asr_result = ""
                        elif self.backend == "onnxruntime":
                            from rapid_paraformer.utils.frontend import load_bytes
                            from funasr_onnx.utils.frontend import load_bytes
                            array = load_bytes(tmp_data)
                            asr_result = self.inference_16k_pipeline(array)[0]
                        end_time = int(round(time.time() * 1000))
funasr/runtime/python/libtorch/README.md
@@ -2,8 +2,6 @@
[FunASR](https://github.com/alibaba-damo-academy/FunASR) hopes to build a bridge between academic research and industrial applications in speech recognition. By supporting the training and fine-tuning of the industrial-grade speech recognition models released on ModelScope, it lets researchers and developers conduct research on, and production of, speech recognition models more conveniently, and promotes the development of the speech recognition ecosystem. ASR for Fun!
### Introduction
- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
### Steps:
1. Export the model.
@@ -25,15 +23,20 @@
    
    install from pip
    ```shell
    pip install --upgrade funasr_torch -i https://pypi.Python.org/simple
    pip install -U funasr_torch
    # For users in China, you can install with the following command:
    # pip install -U funasr_torch -i https://mirror.sjtu.edu.cn/pypi/web/simple
    ```
    or install from source code
    ```shell
    git clone https://github.com/alibaba/FunASR.git && cd FunASR
    cd funasr/runtime/python/libtorch
    python setup.py build
    python setup.py install
    pip install -e ./
    # For users in China, you can install with the following command:
    # pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
    ```
3. Run the demo.
funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -23,7 +23,6 @@
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 pred_bias: int = 1,
                 quantize: bool = False,
                 intra_op_num_threads: int = 1,
                 ):
@@ -48,7 +47,10 @@
        self.batch_size = batch_size
        self.device_id = device_id
        self.plot_timestamp_to = plot_timestamp_to
        self.pred_bias = pred_bias
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -23,9 +23,11 @@
                 ):
        check_argument_types()
        # self.token_list = self.load_token(token_path)
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
@@ -38,13 +40,8 @@
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        token2id = {v: i for i, v in enumerate(self.token_list)}
        if self.unk_symbol not in token2id:
            raise TokenIDConverterError(
                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
            )
        unk_id = token2id[self.unk_symbol]
        return [token2id.get(i, unk_id) for i in tokens]
        return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
@@ -134,7 +131,7 @@
@functools.lru_cache()
def get_logger(name='torch_paraformer'):
def get_logger(name='funasr_torch'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
funasr/runtime/python/libtorch/setup.py
@@ -15,10 +15,10 @@
setuptools.setup(
    name='funasr_torch',
    version='0.0.3',
    version='0.0.4',
    platforms="Any",
    url="https://github.com/alibaba-damo-academy/FunASR.git",
    author="Speech Lab, Alibaba Group, China",
    author="Speech Lab of DAMO Academy, Alibaba Group",
    author_email="funasr@list.alibaba-inc.com",
    description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
    license="The MIT License",
@@ -31,7 +31,7 @@
                      "PyYAML>=5.1.2", "torch-quant >= 0.4.0"],
    packages=find_packages(include=["torch_paraformer*"]),
    keywords=[
        'funasr,paraformer, funasr_torch'
        'funasr, paraformer, funasr_torch'
    ],
    classifiers=[
        'Programming Language :: Python :: 3.6',
funasr/runtime/python/onnxruntime/README.md
@@ -1,10 +1,6 @@
## Using funasr with ONNXRuntime
### Introduction
- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
### Steps:
1. Export the model.
   - Command: (`Tips`: torch >= 1.11.0 is required.)
@@ -25,7 +21,10 @@
install from pip
```shell
pip install --upgrade funasr_onnx -i https://pypi.Python.org/simple
pip install -U funasr_onnx
# For users in China, you can install with the following command:
# pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
or install from source code
@@ -33,8 +32,10 @@
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/onnxruntime
python setup.py build
python setup.py install
pip install -e ./
# For users in China, you can install with the following command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
3. Run the demo.
funasr/runtime/python/onnxruntime/demo.py
@@ -1,5 +1,6 @@
from funasr_onnx import Paraformer
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0)  # cpu
@@ -7,7 +8,6 @@
# When using the paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to save an alignment figure in addition to the timestamps.
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
wav_path = "YourPath/xx.wav"
funasr/runtime/python/onnxruntime/demo_punc_offline.py
New file
@@ -0,0 +1,8 @@
# Demo: offline punctuation restoration with the ONNX CT-Transformer model.
from funasr_onnx import CT_Transformer
# Directory of the exported model, given relative to the current directory.
model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model = CT_Transformer(model_dir)
# Unpunctuated input text.
text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
result = model(text_in)
# The first element of the result is the punctuated text.
print(result[0])
funasr/runtime/python/onnxruntime/demo_punc_online.py
New file
@@ -0,0 +1,15 @@
# Demo: streaming punctuation restoration, fed one "|"-separated segment at a time.
from funasr_onnx import CT_Transformer_VadRealtime
# Directory of the exported model, given relative to the current directory.
model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model = CT_Transformer_VadRealtime(model_dir)
# Unpunctuated input; each "|" marks a segment boundary used by the loop below.
text_in  = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
vads = text_in.split("|")
rec_result_all=""
# The same dict is passed to every call so the model can carry its "cache"
# entry from one segment to the next.
param_dict = {"cache": []}
for vad in vads:
    result = model(vad, param_dict=param_dict)
    # result[0] is the punctuated text for this segment.
    rec_result_all += result[0]
print(rec_result_all)
funasr/runtime/python/onnxruntime/demo_vad_offline.py
New file
@@ -0,0 +1,11 @@
# Demo: offline voice activity detection on a whole wav file.
# NOTE(review): soundfile is imported but not used in this script.
import soundfile
from funasr_onnx import Fsmn_vad
# NOTE(review): both paths are machine-specific absolute paths; replace them
# with your own exported model directory and wav file before running.
model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
model = Fsmn_vad(model_dir)
#offline vad
# The model is called with the wav path directly and the result is printed as-is.
result = model(wav_path)
print(result)
funasr/runtime/python/onnxruntime/demo_vad_online.py
New file
@@ -0,0 +1,28 @@
# Demo: streaming voice activity detection, feeding the audio chunk by chunk.
import soundfile
from funasr_onnx import Fsmn_vad_online
# NOTE(review): both paths are machine-specific absolute paths; replace them
# with your own exported model directory and wav file before running.
model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
model = Fsmn_vad_online(model_dir)
##online vad
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
# Number of samples sent per call.
step = 1600
# The same dict is passed to every call so the model state carried in
# 'in_cache' survives from one chunk to the next.
param_dict = {'in_cache': []}
# Advance by the number of samples actually sent. A `range` with a fixed
# stride would fail on empty audio (stride 0) and, when exactly one sample is
# left after a chunk, would send a second "final" chunk.
sample_offset = 0
while sample_offset < speech_length:
    if sample_offset + step >= speech_length - 1:
        # Last chunk: take everything that is left and mark it as final.
        chunk_len = speech_length - sample_offset
        is_final = True
    else:
        chunk_len = step
        is_final = False
    param_dict['is_final'] = is_final
    segments_result = model(audio_in=speech[sample_offset: sample_offset + chunk_len],
                            param_dict=param_dict)
    if segments_result:
        print(segments_result)
    sample_offset += chunk_len
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,2 +1,6 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -23,7 +23,6 @@
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 pred_bias: int = 1,
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 ):
@@ -47,7 +46,10 @@
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        self.pred_bias = pred_bias
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
New file
@@ -0,0 +1,259 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import numpy as np
from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
logging = get_logger()
class CT_Transformer():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf

    Punctuation restoration for a complete text, run through an ONNX Runtime
    session loaded from ``model_dir``.
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 ):
        """Load the ONNX model, the token list and the punctuation list.

        Args:
            model_dir: Directory holding ``model.onnx`` (or
                ``model_quant.onnx``) and ``punc.yaml``.
            batch_size: Accepted for interface compatibility; the instance
                always stores a batch size of 1.
            device_id: Forwarded to ``OrtInferSession``.
            quantize: Load ``model_quant.onnx`` instead of ``model.onnx``.
            intra_op_num_threads: Forwarded to ``OrtInferSession``.

        Raises:
            FileNotFoundError: If ``model_dir`` is ``None`` or does not exist.
        """
        # Guard against the ``None`` default, which Path() cannot handle.
        if model_dir is None or not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_name = 'model_quant.onnx' if quantize else 'model.onnx'
        model_file = os.path.join(model_dir, model_name)
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        # NOTE(review): the batch_size argument is ignored here.
        self.batch_size = 1
        self.punc_list = config['punc_list']
        # Index of the period mark in punc_list (stays 0 if it is not found).
        self.period = 0
        # NOTE(review): as written, the first two branches store the same
        # string they matched; presumably they were meant to normalise between
        # half-width and full-width marks - confirm against the model config.
        for i, punc in enumerate(self.punc_list):
            if punc == ",":
                self.punc_list[i] = ","
            elif punc == "?":
                self.punc_list[i] = "?"
            elif punc == "。":
                self.period = i

    def __call__(self, text: Union[list, str], split_size=20):
        """Add punctuation to ``text``.

        The text is split into words, cut into mini sentences of
        ``split_size`` words, and each mini sentence is run through the model
        together with the unfinished tail of the previous one.

        Args:
            text: Text to punctuate.
            split_size: Number of words per mini sentence.

        Returns:
            ``(punctuated_text, punctuation_ids)``. For input that contains no
            words, ``("", [])`` is returned.

        Raises:
            ONNXRuntimeError: Re-raised (after logging) when inference fails.
        """
        split_text = code_mix_split_words(text)
        if len(split_text) == 0:
            # Nothing to punctuate; avoid indexing into empty results below.
            return "", []
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        new_mini_sentence_out = ""
        new_mini_sentence_punc_out = []
        cache_pop_trigger_limit = 200
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, (mini_sentence, mini_sentence_id) in enumerate(
                zip(mini_sentences, mini_sentences_id)):
            # Prepend the words carried over from the previous mini sentence.
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
            data = {
                "text": mini_sentence_id[None,:],
                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
            }
            try:
                outputs = self.infer(data['text'], data['text_lengths'])
                y = outputs[0]
                punctuations = np.argmax(y,axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError:
                logging.warning("error")
                # Without predictions the code below would fail with a
                # NameError or silently reuse the previous iteration's
                # predictions, so surface the real error instead.
                raise

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < last_index:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                # Everything after the sentence end is carried to the next round.
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]

            new_mini_sentence_punc += [int(x) for x in punctuations]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if i > 0:
                    # Two neighbouring words that both start with a
                    # single-byte character are separated by a space.
                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                if self.punc_list[punctuations[i]] != "_":
                    words_with_punc.append(self.punc_list[punctuations[i]])
            new_mini_sentence += "".join(words_with_punc)

            # Add Period for the end of the sentence
            new_mini_sentence_out = new_mini_sentence
            new_mini_sentence_punc_out = new_mini_sentence_punc
            if mini_sentence_i == last_index:
                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on the token ids and their length; return its outputs."""
        outputs = self.ort_infer([feats, feats_len])
        return outputs
class CT_Transformer_VadRealtime(CT_Transformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf

    Streaming variant: each call punctuates one new piece of text together
    with the words cached from earlier calls.
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 ):
        """Forward every argument unchanged to ``CT_Transformer.__init__``."""
        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)

    def __call__(self, text: str, param_dict: dict, split_size=20):
        """Punctuate ``text`` in the context of previously cached words.

        Args:
            text: Newly arrived text.
            param_dict: Must contain the key ``"cache"`` (a list of words or
                ``None``). It is updated in place with the new cache.
            split_size: Number of words per mini sentence.

        Returns:
            ``(sentence_out, sentence_punc_list_out, cache_out)``. When neither
            the cache nor ``text`` contains any word, ``("", [], [])``.

        Raises:
            ONNXRuntimeError: Re-raised (after logging) when inference fails.
        """
        cache_key = "cache"
        assert cache_key in param_dict
        cache = param_dict[cache_key]
        if cache is not None and len(cache) > 0:
            precache = "".join(cache)
        else:
            precache = ""
            cache = []
        full_text = precache + text
        split_text = code_mix_split_words(full_text)
        if len(split_text) == 0:
            # Nothing to punctuate; avoid indexing into empty results below.
            param_dict[cache_key] = []
            return "", [], []
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        new_mini_sentence_punc = []
        assert len(mini_sentences) == len(mini_sentences_id)

        cache_sent = []
        cache_sent_id = np.array([], dtype='int32')
        sentence_punc_list = []
        sentence_words_list = []
        cache_pop_trigger_limit = 200
        skip_num = 0
        last_index = len(mini_sentences) - 1
        for mini_sentence_i, (mini_sentence, mini_sentence_id) in enumerate(
                zip(mini_sentences, mini_sentences_id)):
            # Prepend the words carried over from the previous mini sentence.
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
            text_length = len(mini_sentence_id)
            data = {
                "input": mini_sentence_id[None,:],
                "text_lengths": np.array([text_length], dtype='int32'),
                "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
            }
            try:
                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
                y = outputs[0]
                punctuations = np.argmax(y,axis=-1)[0]
                assert punctuations.size == len(mini_sentence)
            except ONNXRuntimeError:
                logging.warning("error")
                # Without predictions the code below would fail with a
                # NameError or silently reuse the previous iteration's
                # predictions, so surface the real error instead.
                raise

            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < last_index:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i

                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = self.period
                # Everything after the sentence end is carried to the next round.
                cache_sent = mini_sentence[sentenceEnd + 1:]
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]

            punctuations_np = [int(x) for x in punctuations]
            new_mini_sentence_punc += punctuations_np
            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
            sentence_words_list += mini_sentence

        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
        for i in range(0, len(sentence_words_list)):
            if i > 0:
                # A word starting with a single-byte character that follows a
                # word ending in one is separated from it by a space.
                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
                    sentence_words_list[i] = " " + sentence_words_list[i]
            # The first len(cache) words came from the cache and are not
            # emitted again.
            if skip_num < len(cache):
                skip_num += 1
            else:
                words_with_punc.append(sentence_words_list[i])
            if skip_num >= len(cache):
                sentence_punc_list_out.append(sentence_punc_list[i])
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)

        # Words after the last period/question mark become the next cache.
        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
                sentenceEnd = i
                break
        cache_out = sentence_words_list[sentenceEnd + 1:]
        # Drop a trailing punctuation mark from the emitted text; the guard
        # keeps empty output from raising IndexError.
        if sentence_out and sentence_out[-1] in self.punc_list:
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        param_dict[cache_key] = cache_out
        return sentence_out, sentence_punc_list_out, cache_out

    def vad_mask(self, size, vad_pos, dtype=bool):
        """Create the VAD attention mask.

        The default dtype is the builtin ``bool`` because the ``np.bool``
        alias is not available in every NumPy release.

        :param int size: size of mask
        :param int vad_pos: index of vad index
        :param dtype: dtype of the returned array
        :rtype: np.ndarray of shape (size, size)
        """
        ret = np.ones((size, size), dtype=dtype)
        if vad_pos <= 0 or vad_pos >= size:
            return ret
        # Zero the block made of rows [0, vad_pos - 1) and columns [vad_pos, size).
        sub_corner = np.zeros(
            (vad_pos - 1, size - vad_pos), dtype=dtype)
        ret[0:vad_pos - 1, vad_pos:] = sub_corner
        return ret

    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray,
              vad_mask: np.ndarray,
              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session on the token ids, their length and both masks."""
        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
        return outputs
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
New file
@@ -0,0 +1,616 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import math
import numpy as np
class VadStateMachine(Enum):
    """States of the VAD state machine held in ``E2EVadModel.vad_state_machine``.

    ``E2EVadModel`` starts in, and resets to,
    ``kVadInStateStartPointNotDetected``.
    """
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3
class FrameState(Enum):
    """Per-frame label fed to ``WindowDetector.DetectOneFrame``.

    There, a speech frame counts as 1 and a silence frame as 0 in the sliding
    window; any other value makes the detector return
    ``AudioChangeState.kChangeStateInvalid``.
    """
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
    """Speech/silence transition reported by ``WindowDetector.DetectOneFrame``."""
    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5
class VadDetectMode(Enum):
    """Detection modes; ``VADXOptions.detect_mode`` stores the integer ``.value``.

    The default of ``VADXOptions.detect_mode`` is the value of
    ``kVadMutipleUtteranceDetectMode``.
    """
    kVadSingleUtteranceDetectMode = 0
    # NOTE: "Mutiple" is misspelled, but the member name is part of the API.
    kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
    """Option container for the VAD post-processing done by ``E2EVadModel``.

    Every constructor argument is stored on the instance under the same name.
    ``sil_pdf_ids`` is stored as a copy so that instances never share (or
    mutate) the default list or the caller's list.

    NOTE(review): names ending in ``_ms`` / ``_time`` are presumably durations
    in milliseconds - confirm against the callers.
    """
    def __init__(
            self,
            sample_rate: int = 16000,
            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
            snr_mode: int = 0,
            max_end_silence_time: int = 800,
            max_start_silence_time: int = 3000,
            do_start_point_detection: bool = True,
            do_end_point_detection: bool = True,
            window_size_ms: int = 200,
            sil_to_speech_time_thres: int = 150,
            speech_to_sil_time_thres: int = 150,
            speech_2_noise_ratio: float = 1.0,
            do_extend: int = 1,
            lookback_time_start_point: int = 200,
            lookahead_time_end_point: int = 100,
            max_single_segment_time: int = 60000,
            nn_eval_block_size: int = 8,
            dcd_block_size: int = 4,
            snr_thres: float = -100.0,
            noise_frame_num_used_for_snr: int = 100,
            decibel_thres: float = -100.0,
            speech_noise_thres: float = 0.6,
            fe_prior_thres: float = 1e-4,
            silence_pdf_num: int = 1,
            sil_pdf_ids: List[int] = [0],
            speech_noise_thresh_low: float = -0.1,
            speech_noise_thresh_high: float = 0.3,
            output_frame_probs: bool = False,
            frame_in_ms: int = 10,
            frame_length_ms: int = 25,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        # Copy: the default list is created once at definition time and would
        # otherwise be shared by every instance built without this argument.
        self.sil_pdf_ids = list(sil_pdf_ids)
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
    """Mutable record for one speech buffer: time bounds, data and boundary flags."""
    def __init__(self):
        # A fresh instance is simply a record in its reset state.
        self.Reset()

    def Reset(self):
        """Put every field back to its initial value (with a new, empty buffer)."""
        self.start_ms, self.end_ms = 0, 0
        self.buffer = []
        self.contain_seg_start_point = self.contain_seg_end_point = False
        self.doa = 0
class E2EVadFrameProb(object):
    """Per-frame record: noise/speech probabilities, score, frame id and frame state."""
    def __init__(self):
        # Float-valued fields start at 0.0, integer-valued fields at 0.
        self.noise_prob = self.speech_prob = self.score = 0.0
        self.frame_id = self.frm_state = 0
class WindowDetector(object):
    """Sliding-window detector of speech/silence transitions.

    Keeps a circular buffer of the last ``win_size_frame`` frame flags
    (1 = speech, 0 = silence) and their running sum, and reports a transition
    when that sum crosses the configured frame-count thresholds.
    """
    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        # Raw configuration.
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms

        # The same quantities expressed as a number of frames.
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)

        # All mutable state (including the window itself) is set up by Reset().
        self.Reset()

    def Reset(self) -> None:
        """Clear the window and return to the silence state."""
        self.win_state = [0] * self.win_size_frame
        self.win_sum = 0
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        """Return the window length in frames."""
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
        """Push one frame label into the window and report the resulting transition.

        ``frame_count`` is accepted but not used. A label that is neither
        speech nor silence returns ``kChangeStateInvalid`` without touching
        the window.
        """
        if frameState == FrameState.kFrameStateSpeech:
            frame_flag = 1
        elif frameState == FrameState.kFrameStateSil:
            frame_flag = 0
        else:
            return AudioChangeState.kChangeStateInvalid

        # Replace the oldest flag in the circular buffer and update the sum.
        slot = self.cur_win_pos
        self.win_sum -= self.win_state[slot]
        self.win_sum += frame_flag
        self.win_state[slot] = frame_flag
        self.cur_win_pos = (slot + 1) % self.win_size_frame

        previous = self.pre_frame_state
        if previous == FrameState.kFrameStateSil:
            if self.win_sum >= self.sil_to_speech_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSpeech
                return AudioChangeState.kChangeStateSil2Speech
            return AudioChangeState.kChangeStateSil2Sil

        if previous == FrameState.kFrameStateSpeech:
            if self.win_sum <= self.speech_to_sil_frmcnt_thres:
                self.pre_frame_state = FrameState.kFrameStateSil
                return AudioChangeState.kChangeStateSpeech2Sil
            return AudioChangeState.kChangeStateSpeech2Speech

        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        """Return the frame size in milliseconds."""
        return int(self.frame_size_ms)
class E2EVadModel():
    def __init__(self, vad_post_args: Dict[str, Any]):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                               self.vad_opts.sil_to_speech_time_thres,
                                               self.vad_opts.speech_to_sil_time_thres,
                                               self.vad_opts.frame_in_ms)
        # self.encoder = encoder
        # init variables
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
    def AllResetDetection(self):
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
    def ResetDetection(self):
        self.continous_silence_frame_count = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.windows_detector.Reset()
        self.sil_frame = 0
        self.frame_probs = []
    def ComputeDecibel(self) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if self.data_buf_all is None:
            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
            self.data_buf = self.data_buf_all
        else:
            self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            self.decibel.append(
                10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
                                0.000001))
    def ComputeScores(self, scores: np.ndarray) -> None:
        # scores = self.encoder(feats, in_cache)  # return B * T * D
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
        if self.scores is None:
            self.scores = scores  # the first calculation
        else:
            self.scores = np.concatenate((self.scores, scores), axis=1)
    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
        while self.data_buf_start_frame < frame_idx:
            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                self.data_buf_start_frame += 1
                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
        self.PopDataBufTillFrame(start_frm)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
        if len(self.data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')
        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
            self.output_data_buf[-1].Reset()
            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
            self.output_data_buf[-1].doa = 0
        cur_seg = self.output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buff现在没做任何操作
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if data_to_pop > len(self.data_buf):
            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
            data_to_pop = len(self.data_buf)
            expected_sample_number = len(self.data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('Something wrong with the VAD algorithm\n')
        self.data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True
    def OnSilenceDetected(self, valid_frame: int):
        self.lastest_confirmed_silence_frame = valid_frame
        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame)
        # silence_detected_callback_
        # pass
    def OnVoiceDetected(self, valid_frame: int) -> None:
        self.latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if self.confirmed_start_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_start_frame = start_frame
        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t)
        if self.vad_opts.do_end_point_detection:
            pass
        if self.confirmed_end_frame != -1:
            print('not reset vad properly\n')
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
            self.sil_frame = 0
            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
        self.number_end_time_detected += 1
    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True)
            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
    def GetLatency(self) -> int:
        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
    def LatencyFrmNumAtStartPoint(self) -> int:
        vad_latency = self.windows_detector.GetWinSize()
        if self.vad_opts.do_extend:
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency
    def GetFrameState(self, t: int) -> FrameState:
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = self.decibel[t]
        cur_snr = cur_decibel - self.noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False)
            return frame_state
        sum_score = 0.0
        noise_prob = 0.0
        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(self.sil_pdf_ids) > 0:
            assert len(self.scores) == 1  # 只支持batch_size = 1的测试
            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            self.frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            if self.noise_average_decibel < -99.9:
                self.noise_average_decibel = cur_decibel
            else:
                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
                        self.vad_opts.noise_frame_num_used_for_snr
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
        return frame_state
    def __call__(self, score: np.ndarray, waveform: np.ndarray,
                is_final: bool = False, max_end_sil: int = 800, online: bool = False
                ):
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(score)
        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, score.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if online:
                        if not self.output_data_buf[i].contain_seg_start_point:
                            continue
                        if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
                            continue
                        start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
                        if self.output_data_buf[i].contain_seg_end_point:
                            end_ms = self.output_data_buf[i].end_ms
                            self.next_seg = True
                            self.output_data_buf_offset += 1
                        else:
                            end_ms = -1
                            self.next_seg = False
                    else:
                        if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                            i].contain_seg_end_point):
                            continue
                        start_ms = self.output_data_buf[i].start_ms
                        end_ms = self.output_data_buf[i].end_ms
                        self.output_data_buf_offset += 1
                    segment = [start_ms, end_ms]
                    segment_batch.append(segment)
            if segment_batch:
                segments.append(segment_batch)
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
        return 0
    def DetectLastFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            if i != 0:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
            else:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
        return 0
    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
        """Feed one classified frame to the window detector and advance the VAD state machine.

        The window detector turns the per-frame state into one of four
        transitions (sil->speech, speech->sil, speech->speech, sil->sil). Each
        transition is handled according to the current ``vad_state_machine``
        value, which may confirm a segment start, extend the current segment,
        or close it.

        Args:
            cur_frm_state: speech/silence decision for the frame.
            cur_frm_idx: index of the frame.
            is_final_frame: True when this is the last frame of the stream.
        """
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # NOTE(review): the compared value is the constant 1.0, so this test depends
            # only on fe_prior_thres -- presumably a placeholder for a real prior; confirm.
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            # silence_frame_count and start_frame = 0 below are assigned but not read afterwards
            silence_frame_count = self.continous_silence_frame_count
            self.continous_silence_frame_count = 0
            self.pre_end_silence_detected = False
            start_frame = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # back-date the start by the detector latency, but never before the buffered data
                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
                self.OnVoiceStart(start_frame)
                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                # report every frame from the back-dated start up to the current one as speech
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t)
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # fill the gap since the last confirmed speech frame
                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t)
                # force an end point once the segment exceeds max_single_segment_time
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # same length cap / extend / final-frame handling as the sil->speech case
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            self.continous_silence_frame_count = 0
            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    # only this branch records that the segment was cut by the length cap
                    self.max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            self.continous_silence_frame_count += 1
            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
                        or (is_final_frame and self.number_end_time_detected == 0):
                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t)
                    # fake start/end at frame 0: nothing is pushed to the output buffer
                    self.OnVoiceStart(0, True)
                    self.OnVoiceEnd(0, True, False);
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    # confirm silence only for frames older than the start-point latency
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                # trailing silence reached the threshold: close the segment
                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        # keep lookahead_time_end_point worth of frames (plus one) inside the segment
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False)
                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    # short silence within the look-ahead window still counts as part of the segment
                    if self.continous_silence_frame_count <= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
            else:
                pass
        # in multiple-utterance mode, re-arm detection as soon as a segment has ended
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection()
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy
import numpy as np
from typeguard import check_argument_types
@@ -153,6 +154,187 @@
        cmvn = np.array([means, vars])
        return cmvn
class WavFrontendOnline(WavFrontend):
    """Streaming front-end: fbank, LFR splicing and CMVN with caches kept between calls.

    State carried from one call to the next:
      * ``input_cache``: trailing samples not yet covered by a full frame shift.
      * ``lfr_splice_cache``: per-batch fbank frames kept as LFR context.
      * ``reserve_waveforms``: waveform samples matching the cached frames.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        # add variables
        self.frame_sample_length = int(self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000)
        self.frame_shift_sample_length = int(self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000)
        self.waveform = None  # kept for backward compatibility; not read by this class
        # Attributes read by get_waveforms()/get_fbank(); initialised here so
        # those getters do not raise AttributeError before the first chunk.
        self.waveforms = None
        self.fbanks = None
        self.fbanks_lens = None
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    @staticmethod
    # inputs has catted the cache
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
        np.ndarray, np.ndarray, int]:
        """
        Apply lfr with data

        Stacks ``lfr_m`` consecutive frames every ``lfr_n`` frames.

        Args:
            inputs: 2-D frame matrix (frames x dims), cache already prepended.
            lfr_m: number of frames stacked into one output frame.
            lfr_n: frame stride between output frames.
            is_final: if True, incomplete trailing windows are padded by
                repeating the last input frame; otherwise stacking stops at
                the first incomplete window.

        Returns:
            ``(lfr_frames, splice_cache, splice_idx)``: the stacked float32
            frames, the input frames from ``splice_idx`` onwards (to be
            prepended to the next chunk), and ``splice_idx`` itself.
        """
        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n:]).reshape(-1)
                    for _ in range(num_padding):
                        frame = np.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break the circle
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        if LFR_inputs:
            LFR_outputs = np.vstack(LFR_inputs)
        else:
            # np.vstack raises on an empty list; return an empty frame matrix
            # of the stacked width instead of crashing on too-short input.
            LFR_outputs = np.empty((0, lfr_m * inputs.shape[1]), dtype=np.float32)
        return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
        """Return how many full frames fit into ``sample_length`` samples (0 if none)."""
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def fbank(
            self,
            input: np.ndarray,
            input_lengths: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute fbank features for the cached samples plus the new chunk.

        Args:
            input: waveform batch, concatenated with ``input_cache`` along axis 1.
            input_lengths: accepted for interface compatibility; not used here.

        Returns:
            ``(waveforms, feats, feats_lens)``. All three are empty arrays when
            the buffered audio does not yet contain one full frame.
        """
        # A fresh extractor is created per call; cached samples are re-fed below.
        self.fbank_fn = knf.OnlineFbank(self.opts)
        batch_size = input.shape[0]
        if self.input_cache is None:
            self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
        input = np.concatenate((self.input_cache, input), axis=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
        # update self.in_cache
        # Slice from the front: the former `input[:, -(remaining):]` returned
        # the whole buffer instead of nothing when `remaining` was 0.
        self.input_cache = input[:, frame_num * self.frame_shift_sample_length:]
        waveforms = np.empty(0, dtype=np.int16)
        feats_pad = np.empty(0, dtype=np.float32)
        feats_lens = np.empty(0, dtype=np.int32)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                waveforms.append(
                    waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
                waveform = waveform * (1 << 15)

                self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
                frames = self.fbank_fn.num_frames_ready
                mat = np.empty([frames, self.opts.mel_opts.num_bins])
                # Separate index name so the batch index `i` is not shadowed.
                for frame_idx in range(frames):
                    mat[frame_idx, :] = self.fbank_fn.get_frame(frame_idx)
                feat = mat.astype(np.float32)
                feat_len = np.array(mat.shape[0]).astype(np.int32)
                # Store the float32 copy; the float64 `mat` was appended before,
                # leaving `feat` unused.
                feats.append(feat)
                feats_lens.append(feat_len)

            waveforms = np.stack(waveforms)
            feats_lens = np.array(feats_lens)
            feats_pad = np.array(feats)
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the features and lengths stored by the last ``fbank`` call."""
        return self.fbanks, self.fbanks_lens

    def lfr_cmvn(
            self,
            input: np.ndarray,
            input_lengths: np.ndarray,
            is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
        """Apply LFR (when configured) and CMVN (when a cmvn file is set) per batch item.

        Updates ``self.lfr_splice_cache[i]`` for every item that goes through LFR.

        Returns:
            ``(feats, feats_lens, lfr_splice_frame_idxs)``; the splice index is
            ``-1`` for items that did not go through LFR.
        """
        batch_size = input.shape[0]
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            lfr_splice_frame_idx = -1
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
                                                                                     is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat)
            feat_length = mat.shape[0]
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = np.array(feats_lens)
        feats_pad = np.array(feats)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def extract_fbank(
            self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the full streaming pipeline on one chunk (batch size 1 only).

        Args:
            input: waveform chunk with a leading batch axis.
            input_lengths: forwarded to ``fbank``.
            is_final: flush the caches and reset state after this chunk.

        Returns:
            ``(feats, feats_lengths)``. ``feats`` is an empty array when not
            enough frames are buffered yet to form one LFR frame.
        """
        batch_size = input.shape[0]
        assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
        waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = waveforms if self.reserve_waveforms is None else np.concatenate(
                (self.reserve_waveforms, waveforms), axis=1)
            if not self.lfr_splice_cache:
                # First chunk: left context is the first frame repeated.
                for i in range(batch_size):
                    self.lfr_splice_cache.append(np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0))

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(feats, feats_lengths, is_final)
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx:  ' + str(reserve_frame_idx))
                    # print('frame_frame:  ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[:,
                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate((self.lfr_splice_cache[i], feats[i]), axis=0)
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            if is_final:
                self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                # Flush cached frames. An empty cache is skipped, because
                # np.stack raises on an empty list.
                if self.lfr_splice_cache:
                    feats = np.stack(self.lfr_splice_cache)
                    feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                    feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        """Return the waveform samples stored by the last ``extract_fbank`` call."""
        return self.waveforms

    def cache_reset(self):
        """Drop all streaming caches and create a fresh fbank extractor."""
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []
def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
@@ -188,4 +370,4 @@
    return feat, feat_len
if __name__ == '__main__':
    test()
    test()
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -24,21 +24,11 @@
                 ):
        check_argument_types()
        # self.token_list = self.load_token(token_path)
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]
    # @staticmethod
    # def load_token(file_path: Union[Path, str]) -> List:
    #     if not Path(file_path).exists():
    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
    #
    #     with open(str(file_path), 'rb') as f:
    #         token_list = pickle.load(f)
    #
    #     if len(token_list) != len(set(token_list)):
    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
    #     return token_list
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
@@ -51,13 +41,8 @@
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        token2id = {v: i for i, v in enumerate(self.token_list)}
        if self.unk_symbol not in token2id:
            raise TokenIDConverterError(
                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
            )
        unk_id = token2id[self.unk_symbol]
        return [token2id.get(i, unk_id) for i in tokens]
        return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
@@ -188,7 +173,7 @@
                 input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
@@ -215,6 +200,38 @@
        if not model_path.is_file():
            raise FileExistsError(f'{model_path} is not a file.')
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Split ``words`` into consecutive chunks of at most ``word_limit`` items.

    A list that already fits is returned as the single element of the result
    (the same list object, not a copy).
    """
    assert word_limit > 1
    total = len(words)
    if total <= word_limit:
        return [words]
    return [words[begin:begin + word_limit] for begin in range(0, total, word_limit)]
def code_mix_split_words(text: str):
    """Tokenise mixed text: ASCII runs become words, other characters stand alone.

    The text is first split on whitespace. Within each piece, consecutive
    characters that encode to a single byte are grouped into one word, and
    every other character (e.g. a Chinese character) is emitted on its own.
    """
    words = []
    for piece in text.split():
        # There is no space in piece.
        ascii_run = []
        for ch in piece:
            if len(ch.encode()) == 1:
                # Single-byte encoding: an ASCII char, keep collecting.
                ascii_run.append(ch)
                continue
            # Multi-byte char: flush the pending ASCII word first.
            if ascii_run:
                words.append("".join(ascii_run))
                ascii_run = []
            words.append(ch)
        if ascii_run:
            words.append("".join(ascii_run))
    return words
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
@@ -226,7 +243,7 @@
@functools.lru_cache()
def get_logger(name='rapdi_paraformer'):
def get_logger(name='funasr_onnx'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
New file
@@ -0,0 +1,280 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel
logging = get_logger()
class Fsmn_vad():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Offline FSMN voice-activity detector running on ONNX Runtime.  Calling the
    instance on audio returns the segments produced by ``E2EVadModel``.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        """Load the ONNX model, the front-end and the post-processing config.

        Args:
            model_dir: Directory holding ``model.onnx`` (or
                ``model_quant.onnx``), ``vad.yaml`` and ``vad.mvn``.
            batch_size: Number of waveforms handled per batch.
            device_id: Forwarded to ``OrtInferSession``.
            quantize: Load ``model_quant.onnx`` instead of ``model.onnx``.
            intra_op_num_threads: Forwarded to ``OrtInferSession``.
            max_end_sil: Overrides ``vad_post_conf.max_end_silence_time`` from
                ``vad.yaml`` when it is not ``None``.

        Raises:
            FileNotFoundError: If ``model_dir`` does not exist.
        """
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        """Return the FSMN cache list, creating zero caches when it is empty.

        Args:
            in_cache: Existing cache list.  A non-empty list is returned
                untouched; an empty list is filled in place; ``None`` (the
                default) creates a fresh list on every call.

        Returns:
            A list with one float32 array of shape
            ``(1, proj_dim, lorder - 1, 1)`` per FSMN layer, or ``in_cache``
            itself when it already had entries.
        """
        # A literal ``[]`` default would be shared (and filled) across calls.
        if in_cache is None:
            in_cache = []
        if len(in_cache) > 0:
            return in_cache
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for _ in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        """Run voice-activity detection on ``audio_in``.

        Args:
            audio_in: A wav path, a list of wav paths, or an already loaded
                waveform array.
            **kwargs: ``param_dict`` may carry an ``in_cache`` list used as
                the initial FSMN cache.

        Returns:
            A list with one segment list per batch slot.
            NOTE(review): when ONNX Runtime inference fails, an empty string
            is returned instead of a list (kept for backward compatibility).
        """
        waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        # One independent list per batch slot.  ``[[]] * n`` would repeat a
        # single list object, so the in-place ``+=`` below would write every
        # slot's segments into all slots.
        segments = [[] for _ in range(self.batch_size)]
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get('param_dict', dict())
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                t_offset = 0
                # Long inputs are fed to the model in chunks of at most 6000 frames.
                step = int(min(feats_len.max(), 6000))
                # The range (and therefore its stride) is built once, before
                # ``step`` is shrunk for the final chunk inside the loop.
                # NOTE(review): ``int(feats_len)`` only works for a single
                # waveform per batch - confirm batch_size > 1 is supported.
                for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
                    if t_offset + step >= feats_len - 1:
                        step = feats_len - t_offset
                        is_final = True
                    else:
                        is_final = False
                    feats_package = feats[:, t_offset:int(t_offset + step), :]
                    # Presumably 160 samples per frame shift and 400 samples per
                    # frame window (10 ms / 25 ms at 16 kHz) - TODO confirm.
                    waveform_package = waveform[:, t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400)]

                    inputs = [feats_package]
                    inputs.extend(in_cache)
                    scores, out_caches = self.infer(inputs)
                    # Carry the FSMN state over to the next chunk.
                    in_cache = out_caches
                    segments_part = self.vad_scorer(scores, waveform_package, is_final=is_final, max_end_sil=self.max_end_sil, online=False)
                    if segments_part:
                        for batch_num in range(0, self.batch_size):
                            segments[batch_num] += segments_part[batch_num]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = ''
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalise the input into a list of waveform arrays.

        Args:
            wav_content: A waveform array (used as is), a wav path, or a list
                of wav paths.
            fs: Sample rate passed to ``librosa.load`` as ``sr``.

        Raises:
            TypeError: If ``wav_content`` is none of the supported types.
        """
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute front-end features for every waveform and pad them.

        Returns:
            ``(feats, feats_len)``: the features zero-padded to the longest
            item and stacked as float32, and the unpadded lengths as int32.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each 2-D feature along axis 0 to ``max_feat_len`` and stack as float32."""
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session; the first output is the scores, the rest are the new caches."""
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
class Fsmn_vad_online():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030

    Streaming FSMN voice-activity detector running on ONNX Runtime.  Each call
    consumes one audio chunk; the FSMN cache is carried between calls through
    the caller-supplied ``param_dict``.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        """Load the ONNX model, the online front-end and the post-processing config.

        Args:
            model_dir: Directory holding ``model.onnx`` (or
                ``model_quant.onnx``), ``vad.yaml`` and ``vad.mvn``.
            batch_size: Stored on the instance.
            device_id: Forwarded to ``OrtInferSession``.
            quantize: Load ``model_quant.onnx`` instead of ``model.onnx``.
            intra_op_num_threads: Forwarded to ``OrtInferSession``.
            max_end_sil: Overrides ``vad_post_conf.max_end_silence_time`` from
                ``vad.yaml`` when it is not ``None``.

        Raises:
            FileNotFoundError: If ``model_dir`` does not exist.
        """
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        """Return the FSMN cache list, creating zero caches when it is empty.

        Args:
            in_cache: Existing cache list.  A non-empty list is returned
                untouched; an empty list is filled in place; ``None`` (the
                default) creates a fresh list on every call.

        Returns:
            A list with one float32 array of shape
            ``(1, proj_dim, lorder - 1, 1)`` per FSMN layer, or ``in_cache``
            itself when it already had entries.
        """
        # A literal ``[]`` default would be shared (and filled) across calls.
        if in_cache is None:
            in_cache = []
        if len(in_cache) > 0:
            return in_cache
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for _ in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
        """Process one audio chunk and return the segments detected so far.

        Args:
            audio_in: Waveform chunk; a leading batch axis is added here.
            **kwargs: ``param_dict`` may carry ``is_final`` (marks the last
                chunk) and ``in_cache``.  The updated cache is written back
                to ``param_dict['in_cache']``, so the same dict has to be
                passed on every call to keep the streaming state.

        Returns:
            The segments returned by the VAD scorer, or an empty list when no
            features were produced or ONNX Runtime inference failed.
        """
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        feats, feats_len = self.extract_feat(waveforms, is_final)
        segments = []
        # The online front-end may return no features for this chunk.
        if feats.size != 0:
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                inputs = [feats]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                param_dict['in_cache'] = out_caches
                # Score against the waveform kept by the front-end rather
                # than the raw chunk that was passed in.
                waveforms = self.frontend.get_waveforms()
                segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil,
                                           online=True)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalise the input into a list of waveform arrays.

        Args:
            wav_content: A waveform array (used as is), a wav path, or a list
                of wav paths.
            fs: Sample rate passed to ``librosa.load`` as ``sr``.

        Raises:
            TypeError: If ``wav_content`` is none of the supported types.
        """
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the online front-end on a batch of waveform chunks.

        Returns:
            ``(feats, feats_len)`` cast to float32 and int32 respectively.
        """
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each 2-D feature along axis 0 to ``max_feat_len`` and stack as float32."""
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ONNX session; the first output is the scores, the rest are the new caches."""
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
funasr/runtime/python/onnxruntime/setup.py
@@ -13,14 +13,14 @@
MODULE_NAME = 'funasr_onnx'
VERSION_NUM = '0.0.2'
VERSION_NUM = '0.0.5'
setuptools.setup(
    name=MODULE_NAME,
    version=VERSION_NUM,
    platforms="Any",
    url="https://github.com/alibaba-damo-academy/FunASR.git",
    author="Speech Lab, Alibaba Group, China",
    author="Speech Lab of DAMO Academy, Alibaba Group",
    author_email="funasr@list.alibaba-inc.com",
    description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
    license='MIT',
funasr/runtime/python/utils/test_rtf_gpu.py
New file
@@ -0,0 +1,58 @@
import time
import sys
import librosa
from funasr.utils.types import str2bool
import argparse
# Benchmark script: measures the real-time factor (RTF) of a Paraformer
# runtime model over the wav files named in a list file.
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, required=True)
parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--wav_file', type=str, default=None,
                    help='list file with one "<key> <wav_path>" entry per line (tab or space separated)')
parser.add_argument('--quantize', type=str2bool, default=False, help='quantized model')
parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx')
parser.add_argument('--batch_size', type=int, default=1, help='batch_size for onnx')
args = parser.parse_args()

# Import only the selected backend so that benchmarking the ONNX runtime does
# not require the libtorch runtime to be importable (and vice versa).
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
else:
    from funasr.runtime.python.libtorch.funasr_torch import Paraformer

model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)


def _wav_path_of(line):
    """Return the second field (the wav path) of a tab- or space-separated list line."""
    separator = "\t" if "\t" in line else " "
    return line.split(separator)[1].strip()


# Read the whole list up front and close the file right away.
with open(args.wav_file, 'r') as wav_file_f:
    wav_files = wav_file_f.readlines()

# warm-up
total = 0.0
num = 30
wav_path = _wav_path_of(wav_files[0])
for i in range(num):
    beg_time = time.time()
    result = model(wav_path)
    end_time = time.time()
    duration = end_time-beg_time
    total += duration
    print(result)
    # NOTE(review): "num" prints the length of the path string, and 5.53 is a
    # hard-coded divisor - presumably the duration in seconds of the warm-up
    # clip; confirm before relying on the warm-up rtf figure.
    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))

# infer time
beg_time = time.time()
wav_path = [_wav_path_of(line) for line in wav_files]
result = model(wav_path)
end_time = time.time()
duration = (end_time-beg_time)*1000
print("total_time_comput_ms: {}".format(int(duration)))

# Total audio length: at sr=16000, samples / 16 gives milliseconds.
duration_time = 0.0
for line in wav_files:
    wav_path = _wav_path_of(line)
    waveform, _ = librosa.load(wav_path, sr=16000)
    duration_time += len(waveform)/16.0
print("total_time_wav_ms: {}".format(int(duration_time)))

print("total_rtf: {:.5}".format(duration/duration_time))
funasr/tasks/abs_task.py
@@ -464,6 +464,12 @@
            default=sys.maxsize,
            help="The maximum number update step to train",
        )
        parser.add_argument(
            "--batch_interval",
            type=int,
            default=10000,
            help="The batch interval for saving model.",
        )
        group.add_argument(
            "--patience",
            type=int_or_none,
@@ -1576,13 +1582,21 @@
    ) -> AbsIterFactory:
        assert check_argument_types()
        if hasattr(args, "frontend_conf"):
            if args.frontend_conf is not None and "fs" in args.frontend_conf:
                dest_sample_rate = args.frontend_conf["fs"]
            else:
                dest_sample_rate = 16000
        else:
            dest_sample_rate = 16000
        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=args.frontend_conf["fs"],
            dest_sample_rate=dest_sample_rate,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
funasr/tasks/asr.py
@@ -412,12 +412,6 @@
            default="13_15",
            help="The range of noise decibel level.",
        )
        parser.add_argument(
            "--batch_interval",
            type=int,
            default=10000,
            help="The batch interval for saving model.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
funasr/tasks/diar.py
@@ -1,3 +1,11 @@
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
import argparse
import logging
import os
funasr/tasks/lm.py
@@ -15,7 +15,7 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
from funasr.lm.espnet_model import ESPnetLanguageModel
from funasr.lm.abs_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetLanguageModel),
            default=get_default_kwargs(LanguageModel),
            help="The keyword arguments for model class.",
        )
@@ -178,7 +178,7 @@
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@
        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
        # 3. Initialize
        if args.init is not None:
funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.punctuation.espnet_model import ESPnetPunctuationModel
from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetPunctuationModel),
            default=get_default_kwargs(PunctuationModel),
            help="The keyword arguments for model class.",
        )
@@ -183,7 +183,7 @@
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
    def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@
        # Assume the last-id is sos_and_eos
        if "punc_weight" in args.model_conf:
            args.model_conf.pop("punc_weight")
        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
        model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
        # FIXME(kamo): Should be done in model?
        # 3. Initialize
funasr/tasks/sv.py
@@ -1,3 +1,7 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import argparse
import logging
import os
funasr/tasks/vad.py
@@ -40,7 +40,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
@@ -81,6 +81,7 @@
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_online=WavFrontendOnline,
    ),
    type_check=AbsFrontend,
    default="default",
@@ -291,7 +292,24 @@
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("e2evad")
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
        return model
@@ -302,6 +320,7 @@
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            device: str = "cpu",
            cmvn_file: Union[Path, str] = None,
    ):
        """Build model from the files.
@@ -325,6 +344,8 @@
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        #if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        model.to(device)
funasr/train/abs_model.py
File was renamed from funasr/punctuation/espnet_model.py
@@ -1,3 +1,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +11,34 @@
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class ESPnetPunctuationModel(AbsESPnetModel):
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
    """Abstract base class for punctuation networks.

    To share the loss calculation among different models, the delegate
    pattern is used here: an instance of this class is passed to
    "PunctuationModel" as its ``punc_model`` argument.
    This "model" is one of the mediator objects for the "Task" class.
    """
    @abstractmethod
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network and return a pair of tensors.

        NOTE(review): the wrapper in this file calls the model as
        ``self.punc_model(text, text_lengths)`` and keeps only the first
        returned value, so ``hidden`` presumably receives the text lengths
        on that path - confirm against the concrete subclasses.
        """
        raise NotImplementedError
    @abstractmethod
    def with_vad(self) -> bool:
        """Tell whether the model needs VAD indexes.

        When this returns True the wrapper asserts that ``vad_indexes`` was
        supplied before calling the model.
        """
        raise NotImplementedError
class PunctuationModel(AbsESPnetModel):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
        assert check_argument_types()
        super().__init__()
@@ -21,12 +46,12 @@
        self.punc_weight = torch.Tensor(punc_weight)
        self.sos = 1
        self.eos = 2
        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id
        #if self.punc_model.with_vad():
        # if self.punc_model.with_vad():
        #    print("This is a vad puncuation model.")
    def nll(
        self,
        text: torch.Tensor,
@@ -54,7 +79,7 @@
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]
        if self.punc_model.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
@@ -62,7 +87,7 @@
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_model(text, text_lengths)
        # Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
@@ -75,7 +100,8 @@
            return nll, text_lengths
        else:
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
                                  ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +113,7 @@
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, text_lengths
    def batchify_nll(self,
                     text: torch.Tensor,
                     punc: torch.Tensor,
@@ -113,7 +139,7 @@
            nlls = []
            x_lengths = []
            max_length = text_lengths.max()
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +158,7 @@
        assert nll.size(0) == total_num
        assert x_lengths.size(0) == total_num
        return nll, x_lengths
    def forward(
        self,
        text: torch.Tensor,
@@ -146,15 +172,15 @@
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight
    def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {}
    def inference(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,
funasr/utils/compute_wer.py
@@ -45,8 +45,8 @@
           if out_item['wrong'] > 0:
               rst['wrong_sentences'] += 1
           cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
           cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
           cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
           cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
           cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
funasr/version.txt
@@ -1 +1 @@
0.3.2
0.4.1
setup.py
@@ -123,7 +123,7 @@
    name="funasr",
    version=version,
    url="https://github.com/alibaba-damo-academy/FunASR.git",
    author="Speech Lab, Alibaba Group, China",
    author="Speech Lab of DAMO Academy, Alibaba Group",
    author_email="funasr@list.alibaba-inc.com",
    description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
    long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
tests/test_asr_inference_pipeline.py
@@ -43,6 +43,7 @@
        rec_result = inference_pipeline(
            audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
        logger.info("asr inference result: {0}".format(rec_result))
        assert rec_result["text"] == "每一天都要快乐喔"
    def test_paraformer(self):
        inference_pipeline = pipeline(
@@ -51,6 +52,7 @@
        rec_result = inference_pipeline(
            audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
        logger.info("asr inference result: {0}".format(rec_result))
        assert rec_result["text"] == "每一天都要快乐喔"
class TestMfccaInferencePipelines(unittest.TestCase):
tests/test_punctuation_pipeline.py
@@ -26,16 +26,14 @@
        inference_pipeline = pipeline(
            task=Tasks.punctuation,
            model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
            model_revision="v1.0.0",
        )
        inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
        vads = inputs.split("|")
        cache_out = []
        rec_result_all = "outputs:"
        param_dict = {"cache": []}
        for vad in vads:
            rec_result = inference_pipeline(text_in=vad, cache=cache_out)
            cache_out = rec_result['cache']
            rec_result_all += rec_result['text']
            rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
            rec_result_all += rec_result["text"]
        logger.info("punctuation inference result: {0}".format(rec_result_all))