From 937e507977cc9e49ce323f8b2933087d0fe52698 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 16 四月 2023 22:29:32 +0800
Subject: [PATCH] Merge pull request #363 from alibaba-damo-academy/main
---
funasr/runtime/onnxruntime/readme.md | 99
funasr/bin/punctuation_infer.py | 2
funasr/runtime/grpc/Readme.md | 62
funasr/bin/lm_inference_launch.py | 3
funasr/export/models/e2e_vad.py | 60
funasr/tasks/vad.py | 25
funasr/runtime/onnxruntime/include/libfunasrapi.h | 77
funasr/tasks/punctuation.py | 14
funasr/bin/asr_inference_paraformer.py | 2
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py | 59
funasr/bin/asr_inference_uniasr_vad.py | 4
funasr/runtime/python/onnxruntime/demo_punc_online.py | 15
funasr/models/encoder/opennmt_encoders/conv_encoder.py | 2
funasr/export/test/test_torchscripts.py | 0
funasr/runtime/onnxruntime/src/precomp.h | 3
tests/test_asr_inference_pipeline.py | 2
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py | 2
funasr/models/encoder/resnet34_encoder.py | 12
funasr/bin/sv_inference_launch.py | 3
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py | 616 +++++++
funasr/models/encoder/sanm_encoder.py | 236 ++
funasr/tasks/diar.py | 8
funasr/tasks/lm.py | 8
funasr/bin/punc_inference_launch.py | 3
funasr/tasks/abs_task.py | 16
funasr/export/models/CT_Transformer.py | 162 +
funasr/runtime/python/onnxruntime/demo_vad_online.py | 28
funasr/bin/asr_inference_uniasr.py | 7
funasr/models/e2e_vad.py | 35
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 19
egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py | 6
funasr/models/e2e_diar_sond.py | 8
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh | 4
funasr/models/e2e_tp.py | 2
funasr/export/models/e2e_asr_paraformer.py | 4
funasr/train/abs_model.py | 54
funasr/bin/asr_inference_rnnt.py | 4
funasr/models/decoder/sanm_decoder.py | 17
funasr/export/test/test_onnx.py | 0
funasr/export/models/__init__.py | 18
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py | 2
funasr/models/decoder/transformer_decoder.py | 2
funasr/runtime/python/onnxruntime/demo.py | 2
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py | 2
funasr/bin/asr_inference_paraformer_streaming.py | 17
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py | 2
funasr/datasets/dataset.py | 2
funasr/runtime/onnxruntime/CMakeLists.txt | 19
funasr/runtime/grpc/paraformer_server.h | 12
funasr/bin/vad_inference_launch.py | 3
funasr/models/e2e_asr_paraformer.py | 4
funasr/export/README.md | 2
funasr/bin/diar_inference_launch.py | 3
funasr/export/test/test_onnx_punc_vadrealtime.py | 22
funasr/runtime/onnxruntime/include/Audio.h | 17
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py | 2
funasr/runtime/python/onnxruntime/README.md | 15
funasr/runtime/python/grpc/grpc_server.py | 2
funasr/runtime/onnxruntime/src/CMakeLists.txt | 29
funasr/export/models/encoder/fsmn_encoder.py | 296 +++
funasr/models/e2e_sv.py | 4
setup.py | 2
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 52
funasr/runtime/onnxruntime/src/resample.h | 137 +
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 259 +++
funasr/export/export_model.py | 52
funasr/runtime/python/libtorch/README.md | 13
funasr/runtime/onnxruntime/tester/tester_rtf.cpp | 18
funasr/export/test/test_onnx_vad.py | 26
funasr/runtime/onnxruntime/src/Audio.cpp | 262 ++
funasr/runtime/onnxruntime/src/FeatureExtract.cpp | 28
funasr/runtime/onnxruntime/src/FeatureExtract.h | 13
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 6
funasr/bin/punctuation_infer_vadrealtime.py | 4
funasr/bin/tp_inference_launch.py | 3
funasr/bin/asr_inference_paraformer_vad_punc.py | 2
funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py | 6
funasr/runtime/grpc/paraformer_server.cc | 73
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 280 +++
funasr/runtime/onnxruntime/tester/CMakeLists.txt | 2
funasr/bin/asr_inference_paraformer_vad.py | 2
funasr/export/models/encoder/sanm_encoder.py | 104 +
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py | 2
README.md | 37
funasr/models/vad_realtime_transformer.py | 10
funasr/runtime/onnxruntime/src/Vocab.cpp | 15
funasr/export/test/__init__.py | 0
docs/modelscope_models.md | 93
funasr/datasets/large_datasets/utils/tokenize.py | 4
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py | 15
funasr/version.txt | 2
funasr/models/e2e_asr_mfcca.py | 6
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py | 93
funasr/runtime/python/grpc/grpc_main_client.py | 62
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py | 2
funasr/runtime/onnxruntime/tester/tester.cpp | 58
funasr/runtime/onnxruntime/src/resample.cc | 305 +++
funasr/runtime/onnxruntime/src/commonfunc.h | 4
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py | 2
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py | 19
funasr/models/target_delay_transformer.py | 11
funasr/datasets/preprocessor.py | 15
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 184 ++
funasr/models/e2e_uni_asr.py | 2
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py | 4
funasr/runtime/python/onnxruntime/demo_vad_offline.py | 11
funasr/lm/abs_model.py | 129 +
funasr/runtime/python/onnxruntime/setup.py | 4
funasr/runtime/python/utils/test_rtf_gpu.py | 58
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py | 184 ++
funasr/tasks/sv.py | 4
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh | 4
funasr/models/decoder/contextual_decoder.py | 2
funasr/runtime/python/onnxruntime/demo_punc_offline.py | 8
funasr/export/test/test_onnx_punc.py | 18
funasr/tasks/asr.py | 6
funasr/runtime/python/libtorch/setup.py | 6
funasr/models/frontend/wav_frontend.py | 12
funasr/modules/streaming_utils/chunk_utilis.py | 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py | 5
/dev/null | 182 --
funasr/runtime/grpc/CMakeLists.txt | 2
egs/aishell/transformer/utils/compute_wer.py | 4
funasr/utils/compute_wer.py | 4
funasr/models/predictor/cif.py | 12
tests/test_punctuation_pipeline.py | 8
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py | 2
funasr/bin/asr_inference_launch.py | 3
128 files changed, 4,324 insertions(+), 795 deletions(-)
diff --git a/README.md b/README.md
index 23f1abe..af38341 100644
--- a/README.md
+++ b/README.md
@@ -7,12 +7,11 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
| [**Highlights**](#highlights)
| [**Installation**](#installation)
-| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
-| [**Model Zoo**](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
| [**Contact**](#contact)
@@ -29,15 +28,37 @@
## Installation
-``` sh
-pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install --editable ./
+Install from pip
+```shell
+pip install -U funasr
+# For the users in China, you could install with the command:
+# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
+
+Or install from source code
+
+
+``` sh
+git clone https://github.com/alibaba/FunASR.git && cd FunASR
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
+```
+If you want to use the pretrained models in ModelScope, you should install the modelscope:
+
+```shell
+pip install -U modelscope
+# For the users in China, you could install with the command:
+# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
For more details, please ref to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
-## Usage
-For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
+[//]: # ()
+[//]: # (## Usage)
+
+[//]: # (For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)))
## Contact
diff --git a/docs/modelscope_models.md b/docs/modelscope_models.md
index 277d8e9..be9a4f8 100644
--- a/docs/modelscope_models.md
+++ b/docs/modelscope_models.md
@@ -6,29 +6,70 @@
## Model Zoo
Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
-| Datasets | Hours | Model | Online/Offline | Language | Framework | Checkpoint |
-|:-----:|:-----:|:--------------:|:--------------:| :---: | :---: | --- |
-| Alibaba Speech Data | 60000 | Paraformer | Offline | CN | Pytorch |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Offline | CN | Tensorflow |[speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Online | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Offline | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 20000 | UniASR | Online | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 20000 | UniASR | Offline | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
-| AISHELL-1 | 178 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) |
-| AISHELL-2 | 1000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) |
-| AISHELL-1 | 178 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
-| AISHELL-2 | 1000 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
-| AISHELL-1 | 178 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
-| AISHELL-2 | 1000 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
+### Speech Recognition Models
+#### Paraformer Models
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
+| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which ould deal with arbitrary length input wav |
+| [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
+| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
+| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
+| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition |
+| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+
+
+#### UniASR Models
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
+| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
+| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
+
+#### Conformer Models
+#### Paraformer Models
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
+| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
+
+#### RNN-T Models
+
+### Voice Activity Detection Models
+
+| Model Name | Training Data | Parameters | Sampling Rate | Notes |
+|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
+
+### Punctuation Restoration Models
+
+| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
+
+### Language Models
+
+| Model Name | Training Data | Parameters | Vocab Size | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
+| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
+
+### Speaker Verification Models
+
+| Model Name | Training Data | Parameters | Vocab Size | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (?hours) | 17.5M | 3465 | |
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (?hours) | 61M | 6135 | |
+
+### Speaker diarization Models
+
+| Model Name | Training Data | Parameters | Notes |
+|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (?hours) | 40.5M | |
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (?hours) | 12M | |
diff --git a/egs/aishell/transformer/utils/compute_wer.py b/egs/aishell/transformer/utils/compute_wer.py
index 349a3f6..26a9f49 100755
--- a/egs/aishell/transformer/utils/compute_wer.py
+++ b/egs/aishell/transformer/utils/compute_wer.py
@@ -45,8 +45,8 @@
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
- cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
- cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
index c016c19..77b2cbd 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -74,7 +74,7 @@
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
+ text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
index b326067..488936c 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
@@ -38,7 +38,7 @@
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
index 54cfec0..0d06377 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer.py
@@ -74,7 +74,7 @@
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
+ text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py
index 2f038a8..c94f685 100644
--- a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/infer_after_finetune.py
@@ -38,7 +38,7 @@
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
index f080257..221479d 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer.sh
@@ -63,8 +63,8 @@
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
echo "Computing WER ..."
- python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
- python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
+ cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
+ cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
tail -n 3 ${output_dir}/1best_recog/text.cer
fi
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
index 295c95d..2d311dd 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/infer_after_finetune.py
@@ -34,7 +34,7 @@
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh
index cdf81dc..6daf7d4 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer.sh
@@ -63,8 +63,8 @@
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
echo "Computing WER ..."
- python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
- python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
+ cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
+ cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
tail -n 3 ${output_dir}/1best_recog/text.cer
fi
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py
index e8fee02..747b49f 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/infer_after_finetune.py
@@ -34,7 +34,7 @@
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
index 5d74837..ce8988e 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
@@ -23,8 +23,7 @@
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
-
+ inference_pipline(audio_in=audio_in)
def modelscope_infer(params):
# prepare for multi-GPU decoding
@@ -75,7 +74,7 @@
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
+ text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
index 861fefb..1e9c4d1 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
@@ -2,52 +2,103 @@
import os
import shutil
+from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
+def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx):
+ output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+ gpu_id = (int(idx) - 1) // njob
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=model_dir,
+ output_dir=output_dir_job,
+ batch_size=1
+ )
+ audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+ inference_pipeline(audio_in=audio_in)
+
def modelscope_infer_after_finetune(params):
- # prepare for decoding
+ # prepare for multi-GPU decoding
+ model_dir = params["model_dir"]
pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
for file_name in params["required_files"]:
if file_name == "configuration.json":
with open(os.path.join(pretrained_model_path, file_name)) as f:
config_dict = json.load(f)
config_dict["model"]["am_model_name"] = params["decoding_model_name"]
- with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+ with open(os.path.join(model_dir, "configuration.json"), "w") as f:
json.dump(config_dict, f, indent=4, separators=(',', ': '))
else:
shutil.copy(os.path.join(pretrained_model_path, file_name),
- os.path.join(params["output_dir"], file_name))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
+ os.path.join(model_dir, file_name))
+ ngpu = params["ngpu"]
+ njob = params["njob"]
+ output_dir = params["output_dir"]
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ os.mkdir(output_dir)
+ split_dir = os.path.join(output_dir, "split")
+ os.mkdir(split_dir)
+ nj = ngpu * njob
+ wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=params["output_dir"],
- output_dir=decoding_path,
- batch_size=1
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(modelscope_infer_after_finetune_core,
+ args=(model_dir, output_dir, split_dir, njob, str(i + 1)))
+ p.close()
+ p.join()
- # computer CER if GT text is set
+ # combine decoding results
+ best_recog_path = os.path.join(output_dir, "1best_recog")
+ os.mkdir(best_recog_path)
+ files = ["text", "token", "score"]
+ for file in files:
+ with open(os.path.join(best_recog_path, file), "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+ with open(job_file) as f_job:
+ lines = f_job.readlines()
+ f.writelines(lines)
+
+ # If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
-
+ text_proc_file = os.path.join(best_recog_path, "token")
+ compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
if __name__ == '__main__':
params = {}
params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline"
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
- params["output_dir"] = "./checkpoint"
+ params["model_dir"] = "./checkpoint"
+ params["output_dir"] = "./results"
params["data_dir"] = "./data/test"
params["decoding_model_name"] = "20epoch.pb"
+ params["ngpu"] = 1
+ params["njob"] = 1
modelscope_infer_after_finetune(params)
+
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
index 5c62362..8b4a04d 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
@@ -75,7 +75,7 @@
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
+ text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
index d73cae2..fd124ff 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
@@ -39,7 +39,7 @@
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
similarity index 87%
rename from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
rename to egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
index 5f4563d..3db6f7d 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
@@ -1,3 +1,9 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+https://arxiv.org/abs/2303.05397
+"""
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
similarity index 60%
copy from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
copy to egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
index 5f4563d..db10193 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
@@ -1,3 +1,9 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+https://arxiv.org/abs/2211.10243
+"""
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -8,17 +14,18 @@
num_workers=0,
task=Tasks.speaker_diarization,
diar_model_config="sond.yaml",
- model='damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch',
- sv_model="damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch",
+ model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
+ sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
sv_model_revision="master",
)
# 浠� audio_list 浣滀负杈撳叆锛屽叾涓涓�涓煶棰戜负寰呮娴嬭闊筹紝鍚庨潰鐨勯煶棰戜负涓嶅悓璇磋瘽浜虹殑澹扮汗娉ㄥ唽璇煶
audio_list = [
- "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
- "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
- "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
- "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
+ "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
]
results = inference_diar_pipline(audio_in=audio_list)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index da1241a..7add960 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 2eeffcd..8cbd419 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -797,7 +797,7 @@
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 66dec39..3520297 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
+import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -607,17 +608,21 @@
):
# 3. Build data-iterator
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
- raw_inputs = _load_bytes(data_path_and_name_and_type[0])
- raw_inputs = torch.tensor(raw_inputs)
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, np.ndarray):
- raw_inputs = torch.tensor(raw_inputs)
is_final = False
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
if param_dict is not None and "is_final" in param_dict:
is_final = param_dict["is_final"]
+
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+ raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+ raw_inputs = torch.tensor(raw_inputs)
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ is_final = True
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, np.ndarray):
+ raw_inputs = torch.tensor(raw_inputs)
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index a0dc0aa..1548f9f 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -338,7 +338,7 @@
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index ab3e1e3..9dc0b79 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -670,7 +670,7 @@
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 4a9ff0b..2189a71 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -738,13 +738,13 @@
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 7961d5a..4aea720 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -37,9 +37,6 @@
from funasr.models.frontend.wav_frontend import WavFrontend
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
class Speech2Text:
"""Speech2Text class
@@ -507,13 +504,13 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index 3164d0d..52c29b8 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -507,13 +507,13 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
- ibest_writer["text"][key] = text_postprocessed
+ ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 85e4518..83436e8 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 492ebab..d229cc6 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index e7e3f15..2c5a286 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/punc_train_vadrealtime.py b/funasr/bin/punc_train_vadrealtime.py
deleted file mode 100644
index c5afaad..0000000
--- a/funasr/bin/punc_train_vadrealtime.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#!/usr/bin/env python3
-import os
-from funasr.tasks.punctuation import PunctuationTask
-
-
-def parse_args():
- parser = PunctuationTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- parser.add_argument(
- "--punc_list",
- type=str,
- default=None,
- help="Punctuation list",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- """
- punc training.
- """
- PunctuationTask.main(args=args, cmd=cmd)
-
-
-if __name__ == "__main__":
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- main(args=args)
diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
index a801ee8..dd28ef8 100644
--- a/funasr/bin/punctuation_infer.py
+++ b/funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index ce1cee8..5157eeb 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -90,7 +90,7 @@
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+ "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 1205d19..64a3cff 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index dd76df6..55debac 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 42c5c1e..8fea8db 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/datasets/dataset.py b/funasr/datasets/dataset.py
index 1595224..979479c 100644
--- a/funasr/datasets/dataset.py
+++ b/funasr/datasets/dataset.py
@@ -115,7 +115,7 @@
# NOTE(kamo): SoundScpReader doesn't support pipe-fashion
# like Kaldi e.g. "cat a.wav |".
# NOTE(kamo): The audio signal is normalized to [-1,1] range.
- loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
+ loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate = dest_sample_rate)
# SoundScpReader.__getitem__() returns Tuple[int, ndarray],
# but ndarray is desired, so Adapter class is inserted here
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index d8ceff2..5a2f921 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -47,8 +47,8 @@
length = len(text)
for i in range(length):
x = text[i]
- if i == length-1 and "punc" in data and text[i].startswith("vad:"):
- vad = x[-1][4:]
+ if i == length-1 and "punc" in data and x.startswith("vad:"):
+ vad = x[4:]
if len(vad) == 0:
vad = -1
else:
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 98cca1d..1adca05 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -786,6 +786,7 @@
) -> Dict[str, np.ndarray]:
for i in range(self.num_tokenizer):
text_name = self.text_name[i]
+ #import pdb; pdb.set_trace()
if text_name in data and self.tokenizer[i] is not None:
text = data[text_name]
text = self.text_cleaner(text)
@@ -800,3 +801,17 @@
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
+ return data
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
diff --git a/funasr/export/README.md b/funasr/export/README.md
index 97a3de9..4d09ff8 100644
--- a/funasr/export/README.md
+++ b/funasr/export/README.md
@@ -7,7 +7,7 @@
## Install modelscope and funasr
-The installation is the same as [funasr](../../README.md)
+The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
## Export model
`Tips`: torch>=1.11.0
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index d3d119c..b69eeee 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -167,31 +167,57 @@
def export(self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- mode: str = 'paraformer',
+ mode: str = None,
):
model_dir = tag_name
- if model_dir.startswith('damo/'):
+ if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
- asr_train_config = os.path.join(model_dir, 'config.yaml')
- asr_model_file = os.path.join(model_dir, 'model.pb')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
- json_file = os.path.join(model_dir, 'configuration.json')
+
if mode is None:
import json
+ json_file = os.path.join(model_dir, 'configuration.json')
with open(json_file, 'r') as f:
config_data = json.load(f)
- mode = config_data['model']['model_config']['mode']
+ if config_data['task'] == "punctuation":
+ mode = config_data['model']['punc_model_config']['mode']
+ else:
+ mode = config_data['model']['model_config']['mode']
if mode.startswith('paraformer'):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode.startswith('uniasr'):
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+ config = os.path.join(model_dir, 'config.yaml')
+ model_file = os.path.join(model_dir, 'model.pb')
+ cmvn_file = os.path.join(model_dir, 'am.mvn')
+ model, asr_train_args = ASRTask.build_model_from_file(
+ config, model_file, cmvn_file, 'cpu'
+ )
+ self.frontend = model.frontend
+ elif mode.startswith('offline'):
+ from funasr.tasks.vad import VADTask
+ config = os.path.join(model_dir, 'vad.yaml')
+ model_file = os.path.join(model_dir, 'vad.pb')
+ cmvn_file = os.path.join(model_dir, 'vad.mvn')
- model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, 'cpu'
- )
- self.frontend = model.frontend
+ model, vad_infer_args = VADTask.build_model_from_file(
+ config, model_file, cmvn_file=cmvn_file, device='cpu'
+ )
+ self.export_config["feats_dim"] = 400
+ self.frontend = model.frontend
+ elif mode.startswith('punc'):
+ from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+ punc_train_config = os.path.join(model_dir, 'config.yaml')
+ punc_model_file = os.path.join(model_dir, 'punc.pb')
+ model, punc_train_args = PUNCTask.build_model_from_file(
+ punc_train_config, punc_model_file, 'cpu'
+ )
+ elif mode.startswith('punc_VadRealtime'):
+ from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+ punc_train_config = os.path.join(model_dir, 'config.yaml')
+ punc_model_file = os.path.join(model_dir, 'punc.pb')
+ model, punc_train_args = PUNCTask.build_model_from_file(
+ punc_train_config, punc_model_file, 'cpu'
+ )
self._export(model, tag_name)
diff --git a/funasr/export/models/CT_Transformer.py b/funasr/export/models/CT_Transformer.py
new file mode 100644
index 0000000..932e3af
--- /dev/null
+++ b/funasr/export/models/CT_Transformer.py
@@ -0,0 +1,162 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
+
+class CT_Transformer(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ model_name='punc_model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+ self.embed = model.embed
+ self.decoder = model.decoder
+ # self.model = model
+ self.feats_dim = self.embed.embedding_dim
+ self.num_embeddings = self.embed.num_embeddings
+ self.model_name = model_name
+
+ if isinstance(model.encoder, SANMEncoder):
+ self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+ else:
+ assert False, "Only support samn encode."
+
+ def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ """Compute loss value from buffer sequences.
+
+ Args:
+ input (torch.Tensor): Input ids. (batch, len)
+ hidden (torch.Tensor): Target ids. (batch, len)
+
+ """
+ x = self.embed(inputs)
+ # mask = self._target_mask(input)
+ h, _ = self.encoder(x, text_lengths)
+ y = self.decoder(h)
+ return y
+
+ def get_dummy_inputs(self):
+ length = 120
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
+ text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+ return (text_indexes, text_lengths)
+
+ def get_input_names(self):
+ return ['inputs', 'text_lengths']
+
+ def get_output_names(self):
+ return ['logits']
+
+ def get_dynamic_axes(self):
+ return {
+ 'inputs': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'text_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ }
+
+
+class CT_Transformer_VadRealtime(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ model_name='punc_model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+
+ self.embed = model.embed
+ if isinstance(model.encoder, SANMVadEncoder):
+ self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
+ else:
+ assert False, "Only support samn encode."
+ self.decoder = model.decoder
+ self.model_name = model_name
+
+
+
+ def forward(self, inputs: torch.Tensor,
+ text_lengths: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ sub_masks: torch.Tensor,
+ ) -> Tuple[torch.Tensor, None]:
+ """Compute loss value from buffer sequences.
+
+ Args:
+ input (torch.Tensor): Input ids. (batch, len)
+ hidden (torch.Tensor): Target ids. (batch, len)
+
+ """
+ x = self.embed(inputs)
+ # mask = self._target_mask(input)
+ h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+ y = self.decoder(h)
+ return y
+
+ def with_vad(self):
+ return True
+
+ def get_dummy_inputs(self):
+ length = 120
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
+ text_lengths = torch.tensor([length], dtype=torch.int32)
+ vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+ sub_masks = torch.ones(length, length, dtype=torch.float32)
+ sub_masks = torch.tril(sub_masks).type(torch.float32)
+ return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+ def get_input_names(self):
+ return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+ def get_output_names(self):
+ return ['logits']
+
+ def get_dynamic_axes(self):
+ return {
+ 'inputs': {
+ 1: 'feats_length'
+ },
+ 'vad_masks': {
+ 2: 'feats_length1',
+ 3: 'feats_length2'
+ },
+ 'sub_masks': {
+ 2: 'feats_length1',
+ 3: 'feats_length2'
+ },
+ 'logits': {
+ 1: 'logits_length'
+ },
+ }
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 0012377..0e3a782 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,13 +1,25 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-from funasr.models.e2e_uni_asr import UniASR
-
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
+ elif isinstance(model, E2EVadModel):
+ return E2EVadModel_export(model, **export_config)
+ elif isinstance(model, PunctuationModel):
+ if isinstance(model.punc_model, TargetDelayTransformer):
+ return CT_Transformer_export(model.punc_model, **export_config)
+ elif isinstance(model.punc_model, VadRealtimeTransformer):
+ return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
else:
- raise "Funasr does not support the given model type currently."
\ No newline at end of file
+ raise "Funasr does not support the given model type currently."
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 0db61e0..52ad320 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -19,7 +19,7 @@
class Paraformer(nn.Module):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -112,7 +112,7 @@
class BiCifParaformer(nn.Module):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py
new file mode 100644
index 0000000..d3e8f30
--- /dev/null
+++ b/funasr/export/models/e2e_vad.py
@@ -0,0 +1,60 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+
+import torch
+from torch import nn
+import math
+
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
+
+class E2EVadModel(nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ feats_dim=400,
+ model_name='model',
+ **kwargs,):
+ super(E2EVadModel, self).__init__()
+ self.feats_dim = feats_dim
+ self.max_seq_len = max_seq_len
+ self.model_name = model_name
+ if isinstance(model.encoder, FSMN):
+ self.encoder = FSMN_export(model.encoder)
+ else:
+ raise "unsupported encoder"
+
+
+ def forward(self, feats: torch.Tensor, *args, ):
+
+ scores, out_caches = self.encoder(feats, *args)
+ return scores, out_caches
+
+ def get_dummy_inputs(self, frame=30):
+ speech = torch.randn(1, frame, self.feats_dim)
+ in_cache0 = torch.randn(1, 128, 19, 1)
+ in_cache1 = torch.randn(1, 128, 19, 1)
+ in_cache2 = torch.randn(1, 128, 19, 1)
+ in_cache3 = torch.randn(1, 128, 19, 1)
+
+ return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
+
+ # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+ # import numpy as np
+ # fbank = np.loadtxt(txt_file)
+ # fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+ # speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+ # speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+ # return (speech, speech_lengths)
+
+ def get_input_names(self):
+ return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+
+ def get_output_names(self):
+ return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+
+ def get_dynamic_axes(self):
+ return {
+ 'speech': {
+ 1: 'feats_length'
+ },
+ }
diff --git a/funasr/export/models/encoder/fsmn_encoder.py b/funasr/export/models/encoder/fsmn_encoder.py
new file mode 100755
index 0000000..b8e6433
--- /dev/null
+++ b/funasr/export/models/encoder/fsmn_encoder.py
@@ -0,0 +1,296 @@
+from typing import Tuple, Dict
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.encoder.fsmn_encoder import BasicBlock
+
+class LinearTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(LinearTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+
+class AffineTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(AffineTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+
+class RectifiedLinear(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(RectifiedLinear, self).__init__()
+ self.dim = input_dim
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, input):
+ out = self.relu(input)
+ return out
+
+
+class FSMNBlock(nn.Module):
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ lorder=None,
+ rorder=None,
+ lstride=1,
+ rstride=1,
+ ):
+ super(FSMNBlock, self).__init__()
+
+ self.dim = input_dim
+
+ if lorder is None:
+ return
+
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+
+ self.conv_left = nn.Conv2d(
+ self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+
+ if self.rorder > 0:
+ self.conv_right = nn.Conv2d(
+ self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+ else:
+ self.conv_right = None
+
+ def forward(self, input: torch.Tensor, cache: torch.Tensor):
+ x = torch.unsqueeze(input, 1)
+ x_per = x.permute(0, 3, 2, 1) # B D T C
+
+ cache = cache.to(x_per.device)
+ y_left = torch.cat((cache, x_per), dim=2)
+ cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+ y_left = self.conv_left(y_left)
+ out = x_per + y_left
+
+ if self.conv_right is not None:
+ # maybe need to check
+ y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+ y_right = y_right[:, :, self.rstride:, :]
+ y_right = self.conv_right(y_right)
+ out += y_right
+
+ out_per = out.permute(0, 3, 2, 1)
+ output = out_per.squeeze(1)
+
+ return output, cache
+
+
+class BasicBlock_export(nn.Module):
+ def __init__(self,
+ model,
+ ):
+ super(BasicBlock_export, self).__init__()
+ self.linear = model.linear
+ self.fsmn_block = model.fsmn_block
+ self.affine = model.affine
+ self.relu = model.relu
+
+ def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+ x = self.linear(input) # B T D
+ # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ # if cache_layer_name not in in_cache:
+ # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x, out_cache = self.fsmn_block(x, in_cache)
+ x = self.affine(x)
+ x = self.relu(x)
+ return x, out_cache
+
+
+# class FsmnStack(nn.Sequential):
+# def __init__(self, *args):
+# super(FsmnStack, self).__init__(*args)
+#
+# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+# x = input
+# for module in self._modules.values():
+# x = module(x, in_cache)
+# return x
+
+
+'''
+FSMN net for keyword spotting
+input_dim: input dimension
+linear_dim: fsmn input dimensionll
+proj_dim: fsmn projection dimension
+lorder: fsmn left order
+rorder: fsmn right order
+num_syn: output dimension
+fsmn_layers: no. of sequential fsmn layers
+'''
+
+
+class FSMN(nn.Module):
+ def __init__(
+ self, model,
+ ):
+ super(FSMN, self).__init__()
+
+ # self.input_dim = input_dim
+ # self.input_affine_dim = input_affine_dim
+ # self.fsmn_layers = fsmn_layers
+ # self.linear_dim = linear_dim
+ # self.proj_dim = proj_dim
+ # self.output_affine_dim = output_affine_dim
+ # self.output_dim = output_dim
+ #
+ # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ # self.relu = RectifiedLinear(linear_dim, linear_dim)
+ # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ # range(fsmn_layers)])
+ # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+ # self.softmax = nn.Softmax(dim=-1)
+ self.in_linear1 = model.in_linear1
+ self.in_linear2 = model.in_linear2
+ self.relu = model.relu
+ # self.fsmn = model.fsmn
+ self.out_linear1 = model.out_linear1
+ self.out_linear2 = model.out_linear2
+ self.softmax = model.softmax
+ self.fsmn = model.fsmn
+ for i, d in enumerate(model.fsmn):
+ if isinstance(d, BasicBlock):
+ self.fsmn[i] = BasicBlock_export(d)
+
+ def fuse_modules(self):
+ pass
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ *args,
+ ):
+ """
+ Args:
+ input (torch.Tensor): Input tensor (B, T, D)
+ in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+ """
+
+ x = self.in_linear1(input)
+ x = self.in_linear2(x)
+ x = self.relu(x)
+ # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
+ out_caches = list()
+ for i, d in enumerate(self.fsmn):
+ in_cache = args[i]
+ x, out_cache = d(x, in_cache)
+ out_caches.append(out_cache)
+ x = self.out_linear1(x)
+ x = self.out_linear2(x)
+ x = self.softmax(x)
+
+ return x, out_caches
+
+
+'''
+one deep fsmn layer
+dimproj: projection dimension, input and output dimension of memory blocks
+dimlinear: dimension of mapping layer
+lorder: left order
+rorder: right order
+lstride: left stride
+rstride: right stride
+'''
+
+
+class DFSMN(nn.Module):
+
+ def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
+ super(DFSMN, self).__init__()
+
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+
+ self.expand = AffineTransform(dimproj, dimlinear)
+ self.shrink = LinearTransform(dimlinear, dimproj)
+
+ self.conv_left = nn.Conv2d(
+ dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+
+ if rorder > 0:
+ self.conv_right = nn.Conv2d(
+ dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+ else:
+ self.conv_right = None
+
+ def forward(self, input):
+ f1 = F.relu(self.expand(input))
+ p1 = self.shrink(f1)
+
+ x = torch.unsqueeze(p1, 1)
+ x_per = x.permute(0, 3, 2, 1)
+
+ y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+ if self.conv_right is not None:
+ y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
+ y_right = y_right[:, :, self.rstride:, :]
+ out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
+ else:
+ out = x_per + self.conv_left(y_left)
+
+ out1 = out.permute(0, 3, 2, 1)
+ output = input + out1.squeeze(1)
+
+ return output
+
+
+'''
+build stacked dfsmn layers
+'''
+
+
+def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
+ repeats = [
+ nn.Sequential(
+ DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
+ for i in range(fsmn_layers)
+ ]
+
+ return nn.Sequential(*repeats)
+
+
+if __name__ == '__main__':
+ fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
+ print(fsmn)
+
+ num_params = sum(p.numel() for p in fsmn.parameters())
+ print('the number of model params: {}'.format(num_params))
+ x = torch.zeros(128, 200, 400) # batch-size * time * dim
+ y, _ = fsmn(x) # batch-size * time * dim
+ print('input shape: {}'.format(x.shape))
+ print('output shape: {}'.format(y.shape))
+
+ print(fsmn.to_kaldi_net())
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
index 8a50538..f583f56 100644
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ b/funasr/export/models/encoder/sanm_encoder.py
@@ -9,6 +9,7 @@
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
+
class SANMEncoder(nn.Module):
def __init__(
self,
@@ -107,3 +108,106 @@
}
}
+
+
+class SANMVadEncoder(nn.Module):
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='encoder',
+ onnx: bool = True,
+ ):
+ super().__init__()
+ self.embed = model.embed
+ self.model = model
+ self.feats_dim = feats_dim
+ self._output_size = model._output_size
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ if hasattr(model, 'encoders0'):
+ for i, d in enumerate(self.model.encoders0):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+ d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+ if isinstance(d.feed_forward, PositionwiseFeedForward):
+ d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+ self.model.encoders0[i] = EncoderLayerSANM_export(d)
+
+ for i, d in enumerate(self.model.encoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+ d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+ if isinstance(d.feed_forward, PositionwiseFeedForward):
+ d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+ self.model.encoders[i] = EncoderLayerSANM_export(d)
+
+ self.model_name = model_name
+ self.num_heads = model.encoders[0].self_attn.h
+ self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+ def prepare_mask(self, mask, sub_masks):
+ mask_3d_btd = mask[:, :, None]
+ mask_4d_bhlt = (1 - sub_masks) * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ vad_masks: torch.Tensor,
+ sub_masks: torch.Tensor,
+ ):
+ speech = speech * self._output_size ** 0.5
+ mask = self.make_pad_mask(speech_lengths)
+ vad_masks = self.prepare_mask(mask, vad_masks)
+ mask = self.prepare_mask(mask, sub_masks)
+
+ if self.embed is None:
+ xs_pad = speech
+ else:
+ xs_pad = self.embed(speech)
+
+ encoder_outs = self.model.encoders0(xs_pad, mask)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+ # encoder_outs = self.model.encoders(xs_pad, mask)
+ for layer_idx, encoder_layer in enumerate(self.model.encoders):
+ if layer_idx == len(self.model.encoders) - 1:
+ mask = vad_masks
+ encoder_outs = encoder_layer(xs_pad, mask)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+ xs_pad = self.model.after_norm(xs_pad)
+
+ return xs_pad, speech_lengths
+
+ def get_output_size(self):
+ return self.model.encoders[0].size
+
+ # def get_dummy_inputs(self):
+ # feats = torch.randn(1, 100, self.feats_dim)
+ # return (feats)
+ #
+ # def get_input_names(self):
+ # return ['feats']
+ #
+ # def get_output_names(self):
+ # return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+ #
+ # def get_dynamic_axes(self):
+ # return {
+ # 'feats': {
+ # 1: 'feats_length'
+ # },
+ # 'encoder_out': {
+ # 1: 'enc_out_length'
+ # },
+ # 'predictor_weight': {
+ # 1: 'pre_out_length'
+ # }
+ #
+ # }
diff --git a/funasr/punctuation/__init__.py b/funasr/export/test/__init__.py
similarity index 100%
rename from funasr/punctuation/__init__.py
rename to funasr/export/test/__init__.py
diff --git a/funasr/export/test_onnx.py b/funasr/export/test/test_onnx.py
similarity index 100%
rename from funasr/export/test_onnx.py
rename to funasr/export/test/test_onnx.py
diff --git a/funasr/export/test/test_onnx_punc.py b/funasr/export/test/test_onnx_punc.py
new file mode 100644
index 0000000..39f85f4
--- /dev/null
+++ b/funasr/export/test/test_onnx_punc.py
@@ -0,0 +1,18 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+ onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
+ sess = onnxruntime.InferenceSession(onnx_path)
+ input_name = [nd.name for nd in sess.get_inputs()]
+ output_name = [nd.name for nd in sess.get_outputs()]
+
+ def _get_feed_dict(text_length):
+ return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
+
+ def _run(feed_dict):
+ output = sess.run(output_name, input_feed=feed_dict)
+ for name, value in zip(output_name, output):
+ print('{}: {}'.format(name, value))
+ _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py
new file mode 100644
index 0000000..86be026
--- /dev/null
+++ b/funasr/export/test/test_onnx_punc_vadrealtime.py
@@ -0,0 +1,22 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+ onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
+ sess = onnxruntime.InferenceSession(onnx_path)
+ input_name = [nd.name for nd in sess.get_inputs()]
+ output_name = [nd.name for nd in sess.get_outputs()]
+
+ def _get_feed_dict(text_length):
+ return {'inputs': np.ones((1, text_length), dtype=np.int64),
+ 'text_lengths': np.array([text_length,], dtype=np.int32),
+ 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
+ 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ }
+
+ def _run(feed_dict):
+ output = sess.run(output_name, input_feed=feed_dict)
+ for name, value in zip(output_name, output):
+ print('{}: {}'.format(name, value))
+ _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_vad.py b/funasr/export/test/test_onnx_vad.py
new file mode 100644
index 0000000..12f058f
--- /dev/null
+++ b/funasr/export/test/test_onnx_vad.py
@@ -0,0 +1,26 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+ onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
+ sess = onnxruntime.InferenceSession(onnx_path)
+ input_name = [nd.name for nd in sess.get_inputs()]
+ output_name = [nd.name for nd in sess.get_outputs()]
+
+ def _get_feed_dict(feats_length):
+
+ return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
+ 'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
+ 'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
+ 'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
+ 'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
+ }
+
+ def _run(feed_dict):
+ output = sess.run(output_name, input_feed=feed_dict)
+ for name, value in zip(output_name, output):
+ print('{}: {}'.format(name, value.shape))
+
+ _run(_get_feed_dict(100))
+ _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test_torchscripts.py b/funasr/export/test/test_torchscripts.py
similarity index 100%
rename from funasr/export/test_torchscripts.py
rename to funasr/export/test/test_torchscripts.py
diff --git a/funasr/lm/abs_model.py b/funasr/lm/abs_model.py
index 0ad1e71..1f3c8a7 100644
--- a/funasr/lm/abs_model.py
+++ b/funasr/lm/abs_model.py
@@ -5,7 +5,17 @@
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+import torch
+import torch.nn.functional as F
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract LM class
@@ -27,3 +37,122 @@
self, input: torch.Tensor, hidden: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
+
+
+class LanguageModel(AbsESPnetModel):
+ def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
+ assert check_argument_types()
+ super().__init__()
+ self.lm = lm
+ self.sos = 1
+ self.eos = 2
+
+ # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
+ self.ignore_id = ignore_id
+
+ def nll(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ max_length: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute negative log likelihood(nll)
+
+ Normally, this function is called in batchify_nll.
+ Args:
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ max_lengths: int
+ """
+ batch_size = text.size(0)
+ # For data parallel
+ if max_length is None:
+ text = text[:, : text_lengths.max()]
+ else:
+ text = text[:, :max_length]
+
+ # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
+ # text: (Batch, Length) -> x, y: (Batch, Length + 1)
+ x = F.pad(text, [1, 0], "constant", self.sos)
+ t = F.pad(text, [0, 1], "constant", self.ignore_id)
+ for i, l in enumerate(text_lengths):
+ t[i, l] = self.eos
+ x_lengths = text_lengths + 1
+
+ # 2. Forward Language model
+ # x: (Batch, Length) -> y: (Batch, Length, NVocab)
+ y, _ = self.lm(x, None)
+
+ # 3. Calc negative log likelihood
+ # nll: (BxL,)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+ # nll: (BxL,) -> (BxL,)
+ if max_length is None:
+ nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
+ else:
+ nll.masked_fill_(
+ make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+ 0.0,
+ )
+ # nll: (BxL,) -> (B, L)
+ nll = nll.view(batch_size, -1)
+ return nll, x_lengths
+
+ def batchify_nll(
+ self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute negative log likelihood(nll) from transformer language model
+
+ To avoid OOM, this fuction seperate the input into batches.
+ Then call nll for each batch and combine and return results.
+ Args:
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ batch_size: int, samples each batch contain when computing nll,
+ you may change this to avoid OOM or increase
+
+ """
+ total_num = text.size(0)
+ if total_num <= batch_size:
+ nll, x_lengths = self.nll(text, text_lengths)
+ else:
+ nlls = []
+ x_lengths = []
+ max_length = text_lengths.max()
+
+ start_idx = 0
+ while True:
+ end_idx = min(start_idx + batch_size, total_num)
+ batch_text = text[start_idx:end_idx, :]
+ batch_text_lengths = text_lengths[start_idx:end_idx]
+ # batch_nll: [B * T]
+ batch_nll, batch_x_lengths = self.nll(
+ batch_text, batch_text_lengths, max_length=max_length
+ )
+ nlls.append(batch_nll)
+ x_lengths.append(batch_x_lengths)
+ start_idx = end_idx
+ if start_idx == total_num:
+ break
+ nll = torch.cat(nlls)
+ x_lengths = torch.cat(x_lengths)
+ assert nll.size(0) == total_num
+ assert x_lengths.size(0) == total_num
+ return nll, x_lengths
+
+ def forward(
+ self, text: torch.Tensor, text_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ nll, y_lengths = self.nll(text, text_lengths)
+ ntokens = y_lengths.sum()
+ loss = nll.sum() / ntokens
+ stats = dict(loss=loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self, text: torch.Tensor, text_lengths: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ return {}
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
deleted file mode 100644
index db11b67..0000000
--- a/funasr/lm/espnet_model.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from typing import Dict
-from typing import Optional
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from typeguard import check_argument_types
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.lm.abs_model import AbsLM
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-
-class ESPnetLanguageModel(AbsESPnetModel):
- def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
- assert check_argument_types()
- super().__init__()
- self.lm = lm
- self.sos = 1
- self.eos = 2
-
- # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
- self.ignore_id = ignore_id
-
- def nll(
- self,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- max_length: Optional[int] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute negative log likelihood(nll)
-
- Normally, this function is called in batchify_nll.
- Args:
- text: (Batch, Length)
- text_lengths: (Batch,)
- max_lengths: int
- """
- batch_size = text.size(0)
- # For data parallel
- if max_length is None:
- text = text[:, : text_lengths.max()]
- else:
- text = text[:, :max_length]
-
- # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
- # text: (Batch, Length) -> x, y: (Batch, Length + 1)
- x = F.pad(text, [1, 0], "constant", self.sos)
- t = F.pad(text, [0, 1], "constant", self.ignore_id)
- for i, l in enumerate(text_lengths):
- t[i, l] = self.eos
- x_lengths = text_lengths + 1
-
- # 2. Forward Language model
- # x: (Batch, Length) -> y: (Batch, Length, NVocab)
- y, _ = self.lm(x, None)
-
- # 3. Calc negative log likelihood
- # nll: (BxL,)
- nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
- # nll: (BxL,) -> (BxL,)
- if max_length is None:
- nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
- else:
- nll.masked_fill_(
- make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
- 0.0,
- )
- # nll: (BxL,) -> (B, L)
- nll = nll.view(batch_size, -1)
- return nll, x_lengths
-
- def batchify_nll(
- self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute negative log likelihood(nll) from transformer language model
-
- To avoid OOM, this fuction seperate the input into batches.
- Then call nll for each batch and combine and return results.
- Args:
- text: (Batch, Length)
- text_lengths: (Batch,)
- batch_size: int, samples each batch contain when computing nll,
- you may change this to avoid OOM or increase
-
- """
- total_num = text.size(0)
- if total_num <= batch_size:
- nll, x_lengths = self.nll(text, text_lengths)
- else:
- nlls = []
- x_lengths = []
- max_length = text_lengths.max()
-
- start_idx = 0
- while True:
- end_idx = min(start_idx + batch_size, total_num)
- batch_text = text[start_idx:end_idx, :]
- batch_text_lengths = text_lengths[start_idx:end_idx]
- # batch_nll: [B * T]
- batch_nll, batch_x_lengths = self.nll(
- batch_text, batch_text_lengths, max_length=max_length
- )
- nlls.append(batch_nll)
- x_lengths.append(batch_x_lengths)
- start_idx = end_idx
- if start_idx == total_num:
- break
- nll = torch.cat(nlls)
- x_lengths = torch.cat(x_lengths)
- assert nll.size(0) == total_num
- assert x_lengths.size(0) == total_num
- return nll, x_lengths
-
- def forward(
- self, text: torch.Tensor, text_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- nll, y_lengths = self.nll(text, text_lengths)
- ntokens = y_lengths.sum()
- loss = nll.sum() / ntokens
- stats = dict(loss=loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
- return loss, stats, weight
-
- def collect_feats(
- self, text: torch.Tensor, text_lengths: torch.Tensor
- ) -> Dict[str, torch.Tensor]:
- return {}
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 3b462e7..78105ab 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -102,7 +102,7 @@
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 3bfcffc..18cd343 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -104,7 +104,6 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -152,7 +151,7 @@
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -400,7 +399,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@@ -410,13 +409,13 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
@@ -813,7 +812,7 @@
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
@@ -1077,7 +1076,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@@ -1087,14 +1086,14 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index 5f1bb24..aed7f20 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -405,7 +405,7 @@
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index 0336133..f22f12a 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -36,7 +36,11 @@
import random
import math
class MFCCA(AbsESPnetModel):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """
+ Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
+ MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
+ https://arxiv.org/abs/2210.05265
+ """
def __init__(
self,
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index f1bb2bf..5c8560d 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -44,7 +44,7 @@
class Paraformer(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -612,7 +612,7 @@
class ParaformerBert(Paraformer):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
"""
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index de669f2..3f7011d 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -36,8 +36,12 @@
class DiarSondModel(AbsESPnetModel):
- """Speaker overlap-aware neural diarization model
- reference: https://arxiv.org/abs/2211.10243
+ """
+ Author: Speech Lab, Alibaba Group, China
+ SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+ https://arxiv.org/abs/2211.10243
+ TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+ https://arxiv.org/abs/2303.05397
"""
def __init__(
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index eff63d9..5b21277 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -1,3 +1,7 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+"""
+
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 887439c..d1367ab 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -32,7 +32,7 @@
class TimestampPredictor(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ac4db32..ca76244 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -40,7 +40,7 @@
class UniASR(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index e6cd7c0..50ec475 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -35,6 +35,11 @@
class VADXOptions:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(
self,
sample_rate: int = 16000,
@@ -99,6 +104,11 @@
class E2EVadSpeechBufWithDoa(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self):
self.start_ms = 0
self.end_ms = 0
@@ -117,6 +127,11 @@
class E2EVadFrameProb(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@@ -126,6 +141,11 @@
class WindowDetector(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
@@ -192,7 +212,12 @@
class E2EVadModel(nn.Module):
- def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -229,6 +254,7 @@
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
+ self.frontend = frontend
def AllResetDetection(self):
self.is_final = False
@@ -459,8 +485,8 @@
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
- i].contain_seg_end_point:
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
@@ -477,8 +503,9 @@
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
- self.ComputeDecibel()
+
self.ComputeScores(feats, in_cache)
+ self.ComputeDecibel()
if not is_final:
self.DetectCommonFrames()
else:
diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
index a33e0b7..eec854f 100644
--- a/funasr/models/encoder/opennmt_encoders/conv_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -67,7 +67,7 @@
class ConvEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Convolution encoder in OpenNMT framework
"""
diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
index cf77bce..db30f08 100644
--- a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -117,7 +117,7 @@
class SelfAttentionEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Self attention encoder in OpenNMT framework
"""
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 7d7179a..93695c8 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -406,6 +406,12 @@
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+ https://arxiv.org/abs/2211.10243
+ """
+
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
@@ -633,6 +639,12 @@
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+ https://arxiv.org/abs/2303.05397
+ """
+
super(ResNet34SpL2RegDiar, self).__init__(
input_size,
use_head_conv=use_head_conv,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 57890ef..7ac9121 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-
+from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -117,7 +117,7 @@
class SANMEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -549,7 +549,7 @@
class SANMEncoderChunkOpt(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -958,3 +958,231 @@
var_dict_tf[name_tf].shape))
return var_dict_torch_update
+
+
+class SANMVadEncoder(AbsEncoder):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size : int = 11,
+ sanm_shfit : int = 0,
+ selfattention_layer_type: str = "sanm",
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ SinusoidalPositionEncoder(),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ if selfattention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+
+ elif selfattention_layer_type == "sanm":
+ self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+ encoder_selfattn_layer_args0 = (
+ attention_heads,
+ input_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ kernel_size,
+ sanm_shfit,
+ )
+
+ self.encoders0 = repeat(
+ 1,
+ lambda lnum: EncoderLayerSANM(
+ input_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+
+ self.encoders = repeat(
+ num_blocks-1,
+ lambda lnum: EncoderLayerSANM(
+ output_size,
+ output_size,
+ self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ prev_states: Not to be used now.
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+ no_future_masks = masks & sub_masks
+ xs_pad *= self.output_size()**0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+ f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ else:
+ xs_pad = self.embed(xs_pad)
+
+ # xs_pad = self.dropout(xs_pad)
+ mask_tup0 = [masks, no_future_masks]
+ encoder_outs = self.encoders0(xs_pad, mask_tup0)
+ xs_pad, _ = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+
+
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ if layer_idx + 1 == len(self.encoders):
+ # This is last layer.
+ coner_mask = torch.ones(masks.size(0),
+ masks.size(-1),
+ masks.size(-1),
+ device=xs_pad.device,
+ dtype=torch.bool)
+ for word_index, length in enumerate(ilens):
+ coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+ vad_indexes[word_index],
+ device=xs_pad.device)
+ layer_mask = masks & coner_mask
+ else:
+ layer_mask = no_future_masks
+ mask_tup1 = [masks, layer_mask]
+ encoder_outs = encoder_layer(xs_pad, mask_tup1)
+ xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 475a939..203f00e 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -38,7 +38,7 @@
return cmvn
-def apply_cmvn(inputs, cmvn_file): # noqa
+def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
@@ -47,7 +47,6 @@
dtype = inputs.dtype
frame, dim = inputs.shape
- cmvn = load_cmvn(cmvn_file)
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
@@ -111,6 +110,7 @@
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -140,8 +140,8 @@
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -194,8 +194,8 @@
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index e80a915..a5273f8 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -234,6 +234,7 @@
last_fire_place = len_time - 1
last_fire_remainds = 0.0
pre_alphas_length = 0
+ last_fire = False
mask_chunk_peak_predictor = None
if cache is not None:
@@ -251,10 +252,15 @@
if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
last_fire_place = len_time - 1 - i
last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
+ last_fire = True
break
- last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
- cache["cif_hidden"] = hidden[:, last_fire_place:, :]
- cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ if last_fire:
+ last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
+ cache["cif_hidden"] = hidden[:, last_fire_place:, :]
+ cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ else:
+ cache["cif_hidden"] = hidden
+ cache["cif_alphas"] = alphas
token_num_int = token_num.floor().type(torch.int32).item()
return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
similarity index 92%
rename from funasr/punctuation/target_delay_transformer.py
rename to funasr/models/target_delay_transformer.py
index 219af26..e893c65 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -5,16 +5,19 @@
import torch
import torch.nn as nn
-from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
-from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
vocab_size: int,
diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
similarity index 92%
rename from funasr/punctuation/vad_realtime_transformer.py
rename to funasr/models/vad_realtime_transformer.py
index 35224f9..fe298ce 100644
--- a/funasr/punctuation/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -6,12 +6,16 @@
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
+from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
vocab_size: int,
diff --git a/funasr/modules/streaming_utils/chunk_utilis.py b/funasr/modules/streaming_utils/chunk_utilis.py
index ea37c68..ed8b31e 100644
--- a/funasr/modules/streaming_utils/chunk_utilis.py
+++ b/funasr/modules/streaming_utils/chunk_utilis.py
@@ -11,7 +11,7 @@
class overlap_chunk():
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py
deleted file mode 100644
index 404d5e8..0000000
--- a/funasr/punctuation/abs_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-
-
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
- """The abstract class
-
- To share the loss calculation way among different models,
- We uses delegate pattern here:
- The instance of this class should be passed to "LanguageModel"
-
- >>> from funasr.punctuation.abs_model import AbsPunctuation
- >>> punc = AbsPunctuation()
- >>> model = ESPnetPunctuationModel(punc=punc)
-
- This "model" is one of mediator objects for "Task" class.
-
- """
-
- @abstractmethod
- def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def with_vad(self) -> bool:
- raise NotImplementedError
diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py
deleted file mode 100644
index 8962093..0000000
--- a/funasr/punctuation/sanm_encoder.py
+++ /dev/null
@@ -1,590 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
-import numpy as np
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.multi_layer_conv import Conv1dLinear
-from funasr.modules.multi_layer_conv import MultiLayeredConv1d
-from funasr.modules.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.subsampling import Conv2dSubsampling
-from funasr.modules.subsampling import Conv2dSubsampling2
-from funasr.modules.subsampling import Conv2dSubsampling6
-from funasr.modules.subsampling import Conv2dSubsampling8
-from funasr.modules.subsampling import TooShortUttError
-from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.mask import subsequent_mask, vad_mask
-
-class EncoderLayerSANM(nn.Module):
- def __init__(
- self,
- in_size,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayerSANM, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(in_size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.in_size = in_size
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
- self.dropout_rate = dropout_rate
-
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
- """Compute encoded features.
-
- Args:
- x_input (torch.Tensor): Input tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
-
- """
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- return x, mask
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- if self.in_size == self.size:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- else:
- x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
- )
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
-
- return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-class SANMEncoder(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- selfattention_layer_type: str = "sanm",
- ):
- assert check_argument_types()
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- self.encoder_selfattn_layer = MultiHeadedAttentionSANM
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad *= self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- ):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling "
- + f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- encoder_outs = self.encoders0(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
-class SANMVadEncoder(AbsEncoder):
- """
- author: Speech Lab, Alibaba Group, China
-
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size : int = 11,
- sanm_shfit : int = 0,
- selfattention_layer_type: str = "sanm",
- ):
- assert check_argument_types()
- super().__init__()
- self._output_size = output_size
-
- if input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(input_size, output_size),
- torch.nn.LayerNorm(output_size),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(output_size, positional_dropout_rate),
- )
- elif input_layer == "conv2d":
- self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d2":
- self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d6":
- self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
- elif input_layer == "conv2d8":
- self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
- elif input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
- SinusoidalPositionEncoder(),
- )
- elif input_layer is None:
- if input_size == output_size:
- self.embed = None
- else:
- self.embed = torch.nn.Linear(input_size, output_size)
- elif input_layer == "pe":
- self.embed = SinusoidalPositionEncoder()
- else:
- raise ValueError("unknown input_layer: " + input_layer)
- self.normalize_before = normalize_before
- if positionwise_layer_type == "linear":
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d":
- positionwise_layer = MultiLayeredConv1d
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- elif positionwise_layer_type == "conv1d-linear":
- positionwise_layer = Conv1dLinear
- positionwise_layer_args = (
- output_size,
- linear_units,
- positionwise_conv_kernel_size,
- dropout_rate,
- )
- else:
- raise NotImplementedError("Support only linear or conv1d.")
-
- if selfattention_layer_type == "selfattn":
- encoder_selfattn_layer = MultiHeadedAttention
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- attention_dropout_rate,
- )
-
- elif selfattention_layer_type == "sanm":
- self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
- encoder_selfattn_layer_args0 = (
- attention_heads,
- input_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- encoder_selfattn_layer_args = (
- attention_heads,
- output_size,
- output_size,
- attention_dropout_rate,
- kernel_size,
- sanm_shfit,
- )
-
- self.encoders0 = repeat(
- 1,
- lambda lnum: EncoderLayerSANM(
- input_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
-
- self.encoders = repeat(
- num_blocks-1,
- lambda lnum: EncoderLayerSANM(
- output_size,
- output_size,
- self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- self.interctc_layer_idx = interctc_layer_idx
- if len(interctc_layer_idx) > 0:
- assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
- self.interctc_use_conditioning = interctc_use_conditioning
- self.conditioning_layer = None
- self.dropout = nn.Dropout(dropout_rate)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- vad_indexes: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- prev_states: Not to be used now.
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
- no_future_masks = masks & sub_masks
- xs_pad *= self.output_size()**0.5
- if self.embed is None:
- xs_pad = xs_pad
- elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
- short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
- if short_status:
- raise TooShortUttError(
- f"has {xs_pad.size(1)} frames and is too short for subsampling " +
- f"(it needs more than {limit_size} frames), return empty results",
- xs_pad.size(1),
- limit_size,
- )
- xs_pad, masks = self.embed(xs_pad, masks)
- else:
- xs_pad = self.embed(xs_pad)
-
- # xs_pad = self.dropout(xs_pad)
- mask_tup0 = [masks, no_future_masks]
- encoder_outs = self.encoders0(xs_pad, mask_tup0)
- xs_pad, _ = encoder_outs[0], encoder_outs[1]
- intermediate_outs = []
- #if len(self.interctc_layer_idx) == 0:
- if False:
- # Here, we should not use the repeat operation to do it for all layers.
- encoder_outs = self.encoders(xs_pad, masks)
- xs_pad, masks = encoder_outs[0], encoder_outs[1]
- else:
- for layer_idx, encoder_layer in enumerate(self.encoders):
- if layer_idx + 1 == len(self.encoders):
- # This is last layer.
- coner_mask = torch.ones(masks.size(0),
- masks.size(-1),
- masks.size(-1),
- device=xs_pad.device,
- dtype=torch.bool)
- for word_index, length in enumerate(ilens):
- coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
- vad_indexes[word_index],
- device=xs_pad.device)
- layer_mask = masks & coner_mask
- else:
- layer_mask = no_future_masks
- mask_tup1 = [masks, layer_mask]
- encoder_outs = encoder_layer(xs_pad, mask_tup1)
- xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
- if layer_idx + 1 in self.interctc_layer_idx:
- encoder_out = xs_pad
-
- # intermediate outputs are also normalized
- if self.normalize_before:
- encoder_out = self.after_norm(encoder_out)
-
- intermediate_outs.append((layer_idx + 1, encoder_out))
-
- if self.interctc_use_conditioning:
- ctc_out = ctc.softmax(encoder_out)
- xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- if len(intermediate_outs) > 0:
- return (xs_pad, intermediate_outs), olens, None
- return xs_pad, olens, None
-
diff --git a/funasr/punctuation/text_preprocessor.py b/funasr/punctuation/text_preprocessor.py
deleted file mode 100644
index c9c4bac..0000000
--- a/funasr/punctuation/text_preprocessor.py
+++ /dev/null
@@ -1,12 +0,0 @@
-def split_to_mini_sentence(words: list, word_limit: int = 20):
- assert word_limit > 1
- if len(words) <= word_limit:
- return [words]
- sentences = []
- length = len(words)
- sentence_len = length // word_limit
- for i in range(sentence_len):
- sentences.append(words[i * word_limit:(i + 1) * word_limit])
- if length % word_limit > 0:
- sentences.append(words[sentence_len * word_limit:])
- return sentences
diff --git a/funasr/runtime/grpc/CMakeLists.txt b/funasr/runtime/grpc/CMakeLists.txt
index 1d5d9a9..c7727d5 100644
--- a/funasr/runtime/grpc/CMakeLists.txt
+++ b/funasr/runtime/grpc/CMakeLists.txt
@@ -74,7 +74,7 @@
"${_target}.cc")
target_link_libraries(${_target}
rg_grpc_proto
- rapidasr
+ funasr
${EXTRA_LIBS}
${_REFLECTION}
${_GRPC_GRPCPP}
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index 6e3516a..82347be 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -53,6 +53,68 @@
python grpc_main_client_mic.py --host $server_ip --port 10108
```
+The `grpc_main_client_mic.py` follows the [original design] (https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data with chunks. If you want to send audio_data in one request, here is an example:
+
+```
+# go to ../python/grpc to find this package
+import paraformer_pb2
+
+
+class RecognizeStub:
+ def __init__(self, channel):
+ self.Recognize = channel.stream_stream(
+ '/paraformer.ASR/Recognize',
+ request_serializer=paraformer_pb2.Request.SerializeToString,
+ response_deserializer=paraformer_pb2.Response.FromString,
+ )
+
+
+async def send(channel, data, speaking, isEnd):
+ stub = RecognizeStub(channel)
+ req = paraformer_pb2.Request()
+ if data:
+ req.audio_data = data
+ req.user = 'zz'
+ req.language = 'zh-CN'
+ req.speaking = speaking
+ req.isEnd = isEnd
+ q = queue.SimpleQueue()
+ q.put(req)
+ return stub.Recognize(iter(q.get, None))
+
+# send the audio data once
+async def grpc_rec(data, grpc_uri):
+ with grpc.insecure_channel(grpc_uri) as channel:
+ b = time.time()
+ response = await send(channel, data, False, False)
+ resp = response.next()
+ text = ''
+ if 'decoding' == resp.action:
+ resp = response.next()
+ if 'finish' == resp.action:
+ text = json.loads(resp.sentence)['text']
+ response = await send(channel, None, False, True)
+ return {
+ 'text': text,
+ 'time': time.time() - b,
+ }
+
+async def test():
+ # fc = FunAsrGrpcClient('127.0.0.1', 9900)
+ # t = await fc.rec(wav.tobytes())
+ # print(t)
+ wav, _ = sf.read('z-10s.wav', dtype='int16')
+ uri = '127.0.0.1:9900'
+ res = await grpc_rec(wav.tobytes(), uri)
+ print(res)
+
+
+if __name__ == '__main__':
+ asyncio.run(test())
+
+```
+
+
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer_server.cc
index 69ce903..2893d4c 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer_server.cc
@@ -15,7 +15,6 @@
#include "paraformer.grpc.pb.h"
#include "paraformer_server.h"
-
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
@@ -24,37 +23,14 @@
using grpc::ServerWriter;
using grpc::Status;
-
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
- AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
+ AsrHanlde=FunASRInit(model_path, thread_num, quantize);
std::cout << "ASRServicer init" << std::endl;
init_flag = 0;
-}
-
-void ASRServicer::clear_states(const std::string& user) {
- clear_buffers(user);
- clear_transcriptions(user);
-}
-
-void ASRServicer::clear_buffers(const std::string& user) {
- if (client_buffers.count(user)) {
- client_buffers.erase(user);
- }
-}
-
-void ASRServicer::clear_transcriptions(const std::string& user) {
- if (client_transcription.count(user)) {
- client_transcription.erase(user);
- }
-}
-
-void ASRServicer::disconnect(const std::string& user) {
- clear_states(user);
- std::cout << "Disconnecting user: " << user << std::endl;
}
grpc::Status ASRServicer::Recognize(
@@ -62,10 +38,20 @@
grpc::ServerReaderWriter<Response, Request>* stream) {
Request req;
+ std::unordered_map<std::string, std::string> client_buffers;
+ std::unordered_map<std::string, std::string> client_transcription;
+
while (stream->Read(&req)) {
if (req.isend()) {
std::cout << "asr end" << std::endl;
- disconnect(req.user());
+ // disconnect
+ if (client_buffers.count(req.user())) {
+ client_buffers.erase(req.user());
+ }
+ if (client_transcription.count(req.user())) {
+ client_transcription.erase(req.user());
+ }
+
Response res;
res.set_sentence(
R"({"success": true, "detail": "asr end"})"
@@ -88,7 +74,7 @@
res.set_language(req.language());
stream->Write(res);
} else if (!req.speaking()) {
- if (client_buffers.count(req.user()) == 0) {
+ if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
Response res;
res.set_sentence(
R"({"success": true, "detail": "waiting_for_voice"})"
@@ -99,14 +85,24 @@
stream->Write(res);
}else {
auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- std::string tmp_data = this->client_buffers[req.user()];
- this->clear_states(req.user());
-
+ if (req.audio_data().size() > 0) {
+ auto& buf = client_buffers[req.user()];
+ buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
+ }
+ std::string tmp_data = client_buffers[req.user()];
+ // clear_states
+ if (client_buffers.count(req.user())) {
+ client_buffers.erase(req.user());
+ }
+ if (client_transcription.count(req.user())) {
+ client_transcription.erase(req.user());
+ }
+
Response res;
res.set_sentence(
R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
);
- int data_len_int = tmp_data.length();
+ int data_len_int = tmp_data.length();
std::string data_len = std::to_string(data_len_int);
std::stringstream ss;
ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
@@ -129,18 +125,15 @@
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
-
-
-
stream->Write(res);
}
else {
- RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
- std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg;
+ FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
+ std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
std::string delay_str = std::to_string(end_time - begin_time);
-
+
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
Response res;
std::stringstream ss;
@@ -150,8 +143,7 @@
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
-
-
+
stream->Write(res);
}
}
@@ -165,10 +157,9 @@
res.set_language(req.language());
stream->Write(res);
}
- }
+ }
return Status::OK;
}
-
void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
std::string server_address;
diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer_server.h
index e42e041..dba1e45 100644
--- a/funasr/runtime/grpc/paraformer_server.h
+++ b/funasr/runtime/grpc/paraformer_server.h
@@ -15,7 +15,7 @@
#include <chrono>
#include "paraformer.grpc.pb.h"
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
using grpc::Server;
@@ -35,22 +35,16 @@
{
std::string msg;
float snippet_time;
-}RPASR_RECOG_RESULT;
+}FUNASR_RECOG_RESULT;
class ASRServicer final : public ASR::Service {
private:
int init_flag;
- std::unordered_map<std::string, std::string> client_buffers;
- std::unordered_map<std::string, std::string> client_transcription;
public:
ASRServicer(const char* model_path, int thread_num, bool quantize);
- void clear_states(const std::string& user);
- void clear_buffers(const std::string& user);
- void clear_transcriptions(const std::string& user);
- void disconnect(const std::string& user);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
- RPASR_HANDLE AsrHanlde;
+ FUNASR_HANDLE AsrHanlde;
};
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 4ffe0f3..6feef92 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -2,24 +2,27 @@
project(FunASRonnx)
-set(CMAKE_CXX_STANDARD 11)
+# set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+include(TestBigEndian)
+test_big_endian(BIG_ENDIAN)
+if(BIG_ENDIAN)
+ message("Big endian system")
+else()
+ message("Little endian system")
+endif()
+
# for onnxruntime
-
IF(WIN32)
-
-
if(CMAKE_CL_64)
link_directories(${ONNXRUNTIME_DIR}\\lib)
else()
add_definitions(-D_WIN_X86)
endif()
ELSE()
-
-
-link_directories(${ONNXRUNTIME_DIR}/lib)
-
+ link_directories(${ONNXRUNTIME_DIR}/lib)
endif()
add_subdirectory("./third_party/yaml-cpp")
diff --git a/funasr/runtime/onnxruntime/include/Audio.h b/funasr/runtime/onnxruntime/include/Audio.h
index da5e82c..ec49a9f 100644
--- a/funasr/runtime/onnxruntime/include/Audio.h
+++ b/funasr/runtime/onnxruntime/include/Audio.h
@@ -6,6 +6,13 @@
#include <queue>
#include <stdint.h>
+#ifndef model_sample_rate
+#define model_sample_rate 16000
+#endif
+#ifndef WAV_HEADER_SIZE
+#define WAV_HEADER_SIZE 44
+#endif
+
using namespace std;
class AudioFrame {
@@ -32,7 +39,6 @@
int16_t *speech_buff;
int speech_len;
int speech_align_len;
- int16_t sample_rate;
int offset;
float align_size;
int data_type;
@@ -43,10 +49,11 @@
Audio(int data_type, int size);
~Audio();
void disp();
- bool loadwav(const char* filename);
- bool loadwav(const char* buf, int nLen);
- bool loadpcmwav(const char* buf, int nFileLen);
- bool loadpcmwav(const char* filename);
+ bool loadwav(const char* filename, int32_t* sampling_rate);
+ void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
+ bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
+ bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
+ bool loadpcmwav(const char* filename, int32_t* sampling_rate);
int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag);
void padding();
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
new file mode 100644
index 0000000..9bc37e7
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -0,0 +1,77 @@
+#pragma once
+
+#ifdef WIN32
+#ifdef _FUNASR_API_EXPORT
+#define _FUNASRAPI __declspec(dllexport)
+#else
+#define _FUNASRAPI __declspec(dllimport)
+#endif
+#else
+#define _FUNASRAPI
+#endif
+
+#ifndef _WIN32
+#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
+#else
+#define FUNASR_CALLBCK_PREFIX __stdcall
+#endif
+
+#ifdef __cplusplus
+
+extern "C" {
+#endif
+
+typedef void* FUNASR_HANDLE;
+typedef void* FUNASR_RESULT;
+typedef unsigned char FUNASR_BOOL;
+
+#define FUNASR_TRUE 1
+#define FUNASR_FALSE 0
+#define QM_DEFAULT_THREAD_NUM 4
+
+typedef enum
+{
+ RASR_NONE=-1,
+ RASRM_CTC_GREEDY_SEARCH=0,
+ RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
+ RASRM_ATTENSION_RESCORING = 2,
+
+}FUNASR_MODE;
+
+typedef enum {
+ FUNASR_MODEL_PADDLE = 0,
+ FUNASR_MODEL_PADDLE_2 = 1,
+ FUNASR_MODEL_K2 = 2,
+ FUNASR_MODEL_PARAFORMER = 3,
+
+}FUNASR_MODEL_TYPE;
+
+typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
+
+// APIs for qmasr
+_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize);
+
+
+// if not give a fnCallback ,it should be NULL
+_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+
+_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);
+
+_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result);
+
+_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result);
+
+_FUNASRAPI void FunASRUninit(FUNASR_HANDLE Handle);
+
+_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result);
+
+#ifdef __cplusplus
+
+}
+#endif
diff --git a/funasr/runtime/onnxruntime/include/librapidasrapi.h b/funasr/runtime/onnxruntime/include/librapidasrapi.h
deleted file mode 100644
index 918e574..0000000
--- a/funasr/runtime/onnxruntime/include/librapidasrapi.h
+++ /dev/null
@@ -1,77 +0,0 @@
-#pragma once
-
-#ifdef WIN32
-#ifdef _RPASR_API_EXPORT
-#define _RAPIDASRAPI __declspec(dllexport)
-#else
-#define _RAPIDASRAPI __declspec(dllimport)
-#endif
-#else
-#define _RAPIDASRAPI
-#endif
-
-#ifndef _WIN32
-#define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
-#else
-#define RPASR_CALLBCK_PREFIX __stdcall
-#endif
-
-#ifdef __cplusplus
-
-extern "C" {
-#endif
-
-typedef void* RPASR_HANDLE;
-typedef void* RPASR_RESULT;
-typedef unsigned char RPASR_BOOL;
-
-#define RPASR_TRUE 1
-#define RPASR_FALSE 0
-#define QM_DEFAULT_THREAD_NUM 4
-
-typedef enum
-{
- RASR_NONE=-1,
- RASRM_CTC_GREEDY_SEARCH=0,
- RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
- RASRM_ATTENSION_RESCORING = 2,
-
-}RPASR_MODE;
-
-typedef enum {
- RPASR_MODEL_PADDLE = 0,
- RPASR_MODEL_PADDLE_2 = 1,
- RPASR_MODEL_K2 = 2,
- RPASR_MODEL_PARAFORMER = 3,
-
-}RPASR_MODEL_TYPE;
-
-typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
-
-// APIs for qmasr
-_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
-
-
-// if not give a fnCallback ,it should be NULL
-_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback);
-
-_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
-
-_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result);
-
-_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result);
-
-_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE Handle);
-
-_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result);
-
-#ifdef __cplusplus
-
-}
-#endif
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index b234e16..dddb46a 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -1,83 +1,70 @@
-## 蹇�熶娇鐢�
-
-### Windows
-
- 瀹夎Vs2022 鎵撳紑cpp_onnx鐩綍涓嬬殑cmake宸ョ▼锛岀洿鎺� build鍗冲彲銆� 鏈粨搴撳凡缁忓噯澶囧ソ鎵�鏈夌浉鍏充緷璧栧簱銆�
-
- Windows涓嬪凡缁忛缃甪ftw3鍙妎nnxruntime搴�
-
-### Linux
-See the bottom of this page: Building Guidance
-
-### 杩愯绋嬪簭
-
-tester /path/to/models_dir /path/to/wave_file quantize(true or false)
-
-渚嬪锛� tester /data/models /data/test.wav false
-
-/data/models 闇�瑕佸寘鎷涓嬩笁涓枃浠�: config.yaml, am.mvn, model.onnx(or model_quant.onnx)
-
-## 鏀寔骞冲彴
-- Windows
-- Linux/Unix
-
-## 渚濊禆
-- fftw3
-- openblas
-- onnxruntime
-
-## 瀵煎嚭onnx鏍煎紡妯″瀷鏂囦欢
-瀹夎 modelscope涓嶧unASR锛屼緷璧栵細torch锛宼orchaudio锛屽畨瑁呰繃绋媅璇︾粏鍙傝�冩枃妗(https://github.com/alibaba-damo-academy/FunASR/wiki)
+## Demo
```shell
-pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install --editable ./
+tester /path/models_dir /path/wave_file quantize(true or false)
```
-瀵煎嚭onnx妯″瀷锛孾璇﹁](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)锛屽弬鑰冪ず渚嬶紝浠巑odelscope涓ā鍨嬪鍑猴細
+
+The structure of /path/models_dir
+```
+config.yaml, am.mvn, model.onnx(or model_quant.onnx)
+```
+
+## Steps
+
+### Export onnx
+#### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
+
+```shell
+pip3 install torch torchaudio
+pip install -U modelscope
+pip install -U funasr
+```
+#### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
-## Building Guidance for Linux/Unix
+### Building for Linux/Unix
-```
-git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
-mkdir build
-cd build
+#### Download onnxruntime
+```shell
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
-# ls
-# onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz
+```
-#install fftw3-dev
-ubuntu: apt install libfftw3-dev
-centos: yum install fftw fftw-devel
+#### Install fftw3
+```shell
+sudo apt install libfftw3-dev #ubuntu
+# sudo yum install fftw fftw-devel #centos
+```
-#install openblas
-bash ./third_party/install_openblas.sh
+#### Install openblas
+```shell
+sudo apt-get install libopenblas-dev #ubuntu
+# sudo yum -y install openblas-devel #centos
+```
-# build
- cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
- make
+#### Build runtime
+```shell
+git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
+make
+```
- # then in the subfolder tester of current direcotry, you will see a program, tester
-
-````
-
-### The structure of a qualified onnxruntime package.
+#### The structure of a qualified onnxruntime package.
```
onnxruntime_xxx
鈹溾攢鈹�鈹�include
鈹斺攢鈹�鈹�lib
```
-## 娉ㄦ剰
-鏈▼搴忓彧鏀寔 閲囨牱鐜�16000hz, 浣嶆繁16bit鐨� **鍗曞0閬�** 闊抽銆�
+### Building for Windows
+Ref to win/
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
diff --git a/funasr/runtime/onnxruntime/src/Audio.cpp b/funasr/runtime/onnxruntime/src/Audio.cpp
index bce3a90..38b6de8 100644
--- a/funasr/runtime/onnxruntime/src/Audio.cpp
+++ b/funasr/runtime/onnxruntime/src/Audio.cpp
@@ -3,10 +3,95 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
+#include <fstream>
+#include <assert.h>
#include "Audio.h"
+#include "precomp.h"
using namespace std;
+
+// see http://soundfile.sapp.org/doc/WaveFormat/
+// Note: We assume little endian here
+struct WaveHeader {
+ bool Validate() const {
+ // F F I R
+ if (chunk_id != 0x46464952) {
+ printf("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
+ return false;
+ }
+ // E V A W
+ if (format != 0x45564157) {
+ printf("Expected format WAVE. Given: 0x%08x\n", format);
+ return false;
+ }
+
+ if (subchunk1_id != 0x20746d66) {
+ printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
+ subchunk1_id);
+ return false;
+ }
+
+ if (subchunk1_size != 16) { // 16 for PCM
+ printf("Expected subchunk1_size 16. Given: %d\n",
+ subchunk1_size);
+ return false;
+ }
+
+ if (audio_format != 1) { // 1 for PCM
+ printf("Expected audio_format 1. Given: %d\n", audio_format);
+ return false;
+ }
+
+ if (num_channels != 1) { // we support only single channel for now
+ printf("Expected single channel. Given: %d\n", num_channels);
+ return false;
+ }
+ if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
+ return false;
+ }
+
+ if (block_align != (num_channels * bits_per_sample / 8)) {
+ return false;
+ }
+
+ if (bits_per_sample != 16) { // we support only 16 bits per sample
+ printf("Expected bits_per_sample 16. Given: %d\n",
+ bits_per_sample);
+ return false;
+ }
+ return true;
+ }
+
+ // See https://en.wikipedia.org/wiki/WAV#Metadata and
+ // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
+ void SeekToDataChunk(std::istream &is) {
+ // a t a d
+ while (is && subchunk2_id != 0x61746164) {
+ // const char *p = reinterpret_cast<const char *>(&subchunk2_id);
+ // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
+ // p[1], p[2], p[3], subchunk2_size);
+ is.seekg(subchunk2_size, std::istream::cur);
+ is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
+ is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
+ }
+ }
+
+ int32_t chunk_id;
+ int32_t chunk_size;
+ int32_t format;
+ int32_t subchunk1_id;
+ int32_t subchunk1_size;
+ int16_t audio_format;
+ int16_t num_channels;
+ int32_t sample_rate;
+ int32_t byte_rate;
+ int16_t block_align;
+ int16_t bits_per_sample;
+ int32_t subchunk2_id; // a tag of this chunk
+ int32_t subchunk2_size; // size of subchunk2
+};
+static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, "");
class AudioWindow {
private:
@@ -56,7 +141,7 @@
float frame_length = 400;
float frame_shift = 160;
float num_new_samples =
- ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length;
+ ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
end = start + num_new_samples;
len = (int)num_new_samples;
@@ -111,62 +196,95 @@
void Audio::disp()
{
- printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000,
+ printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate,
speech_len);
}
float Audio::get_time_len()
{
- return (float)speech_len / 16000;
- //speech_len);
+ return (float)speech_len / model_sample_rate;
}
-bool Audio::loadwav(const char *filename)
+void Audio::wavResample(int32_t sampling_rate, const float *waveform,
+ int32_t n)
{
+ printf(
+ "Creating a resampler:\n"
+ " in_sample_rate: %d\n"
+ " output_sample_rate: %d\n",
+ sampling_rate, static_cast<int32_t>(model_sample_rate));
+ float min_freq =
+ std::min<int32_t>(sampling_rate, model_sample_rate);
+ float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+ int32_t lowpass_filter_width = 6;
+ //FIXME
+ //auto resampler = new LinearResample(
+ // sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
+ auto resampler = std::make_unique<LinearResample>(
+ sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
+ std::vector<float> samples;
+ resampler->Resample(waveform, n, true, &samples);
+ //reset speech_data
+ speech_len = samples.size();
+ if (speech_data != NULL) {
+ free(speech_data);
+ }
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
+ copy(samples.begin(), samples.end(), speech_data);
+}
+
+bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
+{
+ WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
-
+
offset = 0;
-
- FILE *fp;
- fp = fopen(filename, "rb");
- if (fp == nullptr)
+ std::ifstream is(filename, std::ifstream::binary);
+ is.read(reinterpret_cast<char *>(&header), sizeof(header));
+ if(!is){
+ fprintf(stderr, "Failed to read %s\n", filename);
return false;
- fseek(fp, 0, SEEK_END); /*瀹氫綅鍒版枃浠舵湯灏�*/
- uint32_t nFileLen = ftell(fp); /*寰楀埌鏂囦欢澶у皬*/
- fseek(fp, 44, SEEK_SET); /*璺宠繃wav鏂囦欢澶�*/
-
- speech_len = (nFileLen - 44) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len);
+ }
+
+ *sampling_rate = header.sample_rate;
+ // header.subchunk2_size contains the number of bytes in the data.
+ // As we assume each sample contains two bytes, so it is divided by 2 here
+ speech_len = header.subchunk2_size / 2;
+ speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
- int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
- fclose(fp);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+ is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size);
+ if (!is) {
+ fprintf(stderr, "Failed to read %s\n", filename);
+ return false;
+ }
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
-
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
-
return true;
}
@@ -174,57 +292,54 @@
return false;
}
-
-bool Audio::loadwav(const char* buf, int nFileLen)
+bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
{
-
-
-
+ WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
-
offset = 0;
- size_t nOffset = 0;
+ std::memcpy(&header, buf, sizeof(header));
-#define WAV_HEADER_SIZE 44
-
- speech_len = (nFileLen - WAV_HEADER_SIZE) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ *sampling_rate = header.sample_rate;
+ speech_len = header.subchunk2_size / 2;
+ speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t));
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
+ }
+ AudioFrame* frame = new AudioFrame(speech_len);
+ frame_queue.push(frame);
return true;
}
else
return false;
-
}
-
-bool Audio::loadpcmwav(const char* buf, int nBufLen)
+bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@@ -234,32 +349,28 @@
}
offset = 0;
- size_t nOffset = 0;
-
-
-
speech_len = nBufLen / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
-
-
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@@ -269,13 +380,10 @@
}
else
return false;
-
-
}
-bool Audio::loadpcmwav(const char* filename)
+bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
{
-
if (speech_data != NULL) {
free(speech_data);
}
@@ -293,34 +401,31 @@
fseek(fp, 0, SEEK_SET);
speech_len = (nFileLen) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
-
-
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
-
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
+ }
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
-
return true;
}
@@ -328,7 +433,6 @@
return false;
}
-
int Audio::fetch_chunck(float *&dout, int len)
{
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index 2a281eb..d41fcd0 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,43 +1,44 @@
file(GLOB files1 "*.cpp")
+file(GLOB files2 "*.cc")
file(GLOB files4 "paraformer/*.cpp")
set(files ${files1} ${files2} ${files3} ${files4})
# message("${files}")
-add_library(rapidasr ${files})
+add_library(funasr ${files})
if(WIN32)
set(EXTRA_LIBS libfftw3f-3 yaml-cpp)
if(CMAKE_CL_64)
- target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
+ target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
else()
- target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
+ target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
endif()
- target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
+ target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
- target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT)
+ target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
else()
set(EXTRA_LIBS fftw3f pthread yaml-cpp)
- target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include")
- target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib")
+ target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include")
+ target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib")
- target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include")
- target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib")
+ target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include")
+ target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib")
- target_include_directories(rapidasr PUBLIC "/usr/include")
- target_link_directories(rapidasr PUBLIC "/usr/lib64")
+ target_include_directories(funasr PUBLIC "/usr/include")
+ target_link_directories(funasr PUBLIC "/usr/lib64")
- target_include_directories(rapidasr PUBLIC ${FFTW3F_INCLUDE_DIR})
- target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR})
+ target_include_directories(funasr PUBLIC ${FFTW3F_INCLUDE_DIR})
+ target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR})
include_directories(${ONNXRUNTIME_DIR}/include)
endif()
include_directories(${CMAKE_SOURCE_DIR}/include)
-target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS})
+target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
diff --git a/funasr/runtime/onnxruntime/src/FeatureExtract.cpp b/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
index 1b0c3c4..6d2826a 100644
--- a/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
+++ b/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
@@ -5,14 +5,10 @@
FeatureExtract::FeatureExtract(int mode) : mode(mode)
{
- fftw_init();
}
FeatureExtract::~FeatureExtract()
{
- fftwf_free(fft_input);
- fftwf_free(fft_out);
- fftwf_destroy_plan(p);
}
void FeatureExtract::reset()
@@ -26,34 +22,25 @@
return fqueue.size();
}
-void FeatureExtract::fftw_init()
+void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag)
{
- int fft_size = 512;
- fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
- fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+ float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+ fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
memset(fft_input, 0, sizeof(float) * fft_size);
- p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
-}
-void FeatureExtract::insert(float *din, int len, int flag)
-{
const float *window = (const float *)&window_hex;
if (mode == 3)
window = (const float *)&window_hamm_hex;
-
- int window_size = 400;
- int fft_size = 512;
- int window_shift = 160;
speech.load(din, len);
int i, j;
float tmp_feature[80];
if (mode == 0 || mode == 2 || mode == 3) {
- int ll = (speech.size() - 400) / 160 + 1;
+ int ll = (speech.size() - window_size) / window_shift + 1;
fqueue.reinit(ll);
}
- for (i = 0; i <= speech.size() - 400; i = i + window_shift) {
+ for (i = 0; i <= speech.size() - window_size; i = i + window_shift) {
float tmp_mean = 0;
for (j = 0; j < window_size; j++) {
tmp_mean += speech[i + j];
@@ -70,7 +57,7 @@
pre_val = cur_val;
}
- fftwf_execute(p);
+ fftwf_execute_dft_r2c(plan, fft_input, fft_out);
melspect((float *)fft_out, tmp_feature);
int tmp_flag = S_MIDDLE;
@@ -80,6 +67,8 @@
fqueue.push(tmp_feature, tmp_flag);
}
speech.update(i);
+ fftwf_free(fft_input);
+ fftwf_free(fft_out);
}
bool FeatureExtract::fetch(Tensor<float> *&dout)
@@ -128,7 +117,6 @@
void FeatureExtract::melspect(float *din, float *dout)
{
float fftmag[256];
-// float tmp;
const float *melcoe = (const float *)melcoe_hex;
int i;
for (i = 0; i < 256; i++) {
diff --git a/funasr/runtime/onnxruntime/src/FeatureExtract.h b/funasr/runtime/onnxruntime/src/FeatureExtract.h
index f16ea3a..8296253 100644
--- a/funasr/runtime/onnxruntime/src/FeatureExtract.h
+++ b/funasr/runtime/onnxruntime/src/FeatureExtract.h
@@ -14,12 +14,11 @@
SpeechWrap speech;
FeatureQueue fqueue;
int mode;
+ int fft_size = 512;
+ int window_size = 400;
+ int window_shift = 160;
- float *fft_input;
- fftwf_complex *fft_out;
- fftwf_plan p;
-
- void fftw_init();
+ //void fftw_init();
void melspect(float *din, float *dout);
void global_cmvn(float *din);
@@ -27,9 +26,9 @@
FeatureExtract(int mode);
~FeatureExtract();
int size();
- int status();
+ //int status();
void reset();
- void insert(float *din, int len, int flag);
+ void insert(fftwf_plan plan, float *din, int len, int flag);
bool fetch(Tensor<float> *&dout);
};
diff --git a/funasr/runtime/onnxruntime/src/Vocab.cpp b/funasr/runtime/onnxruntime/src/Vocab.cpp
index af6312b..b54a6c6 100644
--- a/funasr/runtime/onnxruntime/src/Vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/Vocab.cpp
@@ -13,21 +13,6 @@
{
ifstream in(filename);
loadVocabFromYaml(filename);
-
- /*
- string line;
- if (in) // 鏈夎鏂囦欢
- {
- while (getline(in, line)) // line涓笉鍖呮嫭姣忚鐨勬崲琛岀
- {
- vocab.push_back(line);
- }
- }
- else{
- printf("Cannot load vocab from: %s, there must be file vocab.txt", filename);
- exit(-1);
- }
- */
}
Vocab::~Vocab()
{
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index 11c234e..5198030 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -5,7 +5,7 @@
{
std::string msg;
float snippet_time;
-}RPASR_RECOG_RESULT;
+}FUNASR_RECOG_RESULT;
#ifdef _WIN32
@@ -53,4 +53,4 @@
}
}
-}
\ No newline at end of file
+}
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
new file mode 100644
index 0000000..a2ecf10
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -0,0 +1,184 @@
+#include "precomp.h"
+#ifdef __cplusplus
+
+extern "C" {
+#endif
+
+ // APIs for qmasr
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
+ {
+ Model* mm = create_model(szModelDir, nThreadNum, quantize);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ {
+ Model* pRecogObj = (Model*)handle;
+ if (!pRecogObj)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ Audio audio(1);
+ if (!audio.loadwav(szBuf, nLen, &sampling_rate))
+ return nullptr;
+ //audio.split();
+
+ float* buff;
+ int len;
+ int flag=0;
+ FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
+ pResult->snippet_time = audio.get_time_len();
+ int nStep = 0;
+ int nTotal = audio.get_queue_size();
+ while (audio.fetch(buff, len, flag) > 0) {
+ //pRecogObj->reset();
+ string msg = pRecogObj->forward(buff, len, flag);
+ pResult->msg += msg;
+ nStep++;
+ if (fnCallback)
+ fnCallback(nStep, nTotal);
+ }
+
+ return pResult;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ {
+ Model* pRecogObj = (Model*)handle;
+ if (!pRecogObj)
+ return nullptr;
+
+ Audio audio(1);
+ if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
+ return nullptr;
+ //audio.split();
+
+ float* buff;
+ int len;
+ int flag = 0;
+ FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
+ pResult->snippet_time = audio.get_time_len();
+ int nStep = 0;
+ int nTotal = audio.get_queue_size();
+ while (audio.fetch(buff, len, flag) > 0) {
+ //pRecogObj->reset();
+ string msg = pRecogObj->forward(buff, len, flag);
+ pResult->msg += msg;
+ nStep++;
+ if (fnCallback)
+ fnCallback(nStep, nTotal);
+ }
+
+ return pResult;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ {
+ Model* pRecogObj = (Model*)handle;
+ if (!pRecogObj)
+ return nullptr;
+
+ Audio audio(1);
+ if (!audio.loadpcmwav(szFileName, &sampling_rate))
+ return nullptr;
+ //audio.split();
+
+ float* buff;
+ int len;
+ int flag = 0;
+ FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
+ pResult->snippet_time = audio.get_time_len();
+ int nStep = 0;
+ int nTotal = audio.get_queue_size();
+ while (audio.fetch(buff, len, flag) > 0) {
+ //pRecogObj->reset();
+ string msg = pRecogObj->forward(buff, len, flag);
+ pResult->msg += msg;
+ nStep++;
+ if (fnCallback)
+ fnCallback(nStep, nTotal);
+ }
+
+ return pResult;
+ }
+
+ _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ {
+ Model* pRecogObj = (Model*)handle;
+ if (!pRecogObj)
+ return nullptr;
+
+ int32_t sampling_rate = -1;
+ Audio audio(1);
+ if(!audio.loadwav(szWavfile, &sampling_rate))
+ return nullptr;
+ //audio.split();
+
+ float* buff;
+ int len;
+ int flag = 0;
+ int nStep = 0;
+ int nTotal = audio.get_queue_size();
+ FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
+ pResult->snippet_time = audio.get_time_len();
+ while (audio.fetch(buff, len, flag) > 0) {
+ //pRecogObj->reset();
+ string msg = pRecogObj->forward(buff, len, flag);
+ pResult->msg+= msg;
+ nStep++;
+ if (fnCallback)
+ fnCallback(nStep, nTotal);
+ }
+
+ return pResult;
+ }
+
+ _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
+ {
+ if (!Result)
+ return 0;
+
+ return 1;
+ }
+
+
+ _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
+ {
+ if (!Result)
+ return 0.0f;
+
+ return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
+ }
+
+ _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
+ {
+ FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
+ if(!pResult)
+ return nullptr;
+
+ return pResult->msg.c_str();
+ }
+
+ _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
+ {
+ if (Result)
+ {
+ delete (FUNASR_RECOG_RESULT*)Result;
+ }
+ }
+
+ _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
+ {
+ Model* pRecogObj = (Model*)handle;
+
+ if (!pRecogObj)
+ return;
+
+ delete pRecogObj;
+ }
+
+#ifdef __cplusplus
+
+}
+#endif
+
diff --git a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp b/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
deleted file mode 100644
index 62f47a5..0000000
--- a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
+++ /dev/null
@@ -1,182 +0,0 @@
-#include "precomp.h"
-#ifdef __cplusplus
-
-extern "C" {
-#endif
-
- // APIs for qmasr
- _RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
- {
- Model* mm = create_model(szModelDir, nThreadNum, quantize);
- return mm;
- }
-
- _RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
- {
- Model* pRecogObj = (Model*)handle;
- if (!pRecogObj)
- return nullptr;
-
- Audio audio(1);
- if (!audio.loadwav(szBuf, nLen))
- return nullptr;
- //audio.split();
-
- float* buff;
- int len;
- int flag=0;
- RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
- pResult->snippet_time = audio.get_time_len();
- int nStep = 0;
- int nTotal = audio.get_queue_size();
- while (audio.fetch(buff, len, flag) > 0) {
- pRecogObj->reset();
- string msg = pRecogObj->forward(buff, len, flag);
- pResult->msg += msg;
- nStep++;
- if (fnCallback)
- fnCallback(nStep, nTotal);
- }
-
- return pResult;
- }
-
- _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
- {
- Model* pRecogObj = (Model*)handle;
- if (!pRecogObj)
- return nullptr;
-
- Audio audio(1);
- if (!audio.loadpcmwav(szBuf, nLen))
- return nullptr;
- //audio.split();
-
- float* buff;
- int len;
- int flag = 0;
- RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
- pResult->snippet_time = audio.get_time_len();
- int nStep = 0;
- int nTotal = audio.get_queue_size();
- while (audio.fetch(buff, len, flag) > 0) {
- pRecogObj->reset();
- string msg = pRecogObj->forward(buff, len, flag);
- pResult->msg += msg;
- nStep++;
- if (fnCallback)
- fnCallback(nStep, nTotal);
- }
-
- return pResult;
- }
-
- _RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
- {
- Model* pRecogObj = (Model*)handle;
- if (!pRecogObj)
- return nullptr;
-
- Audio audio(1);
- if (!audio.loadpcmwav(szFileName))
- return nullptr;
- //audio.split();
-
- float* buff;
- int len;
- int flag = 0;
- RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
- pResult->snippet_time = audio.get_time_len();
- int nStep = 0;
- int nTotal = audio.get_queue_size();
- while (audio.fetch(buff, len, flag) > 0) {
- pRecogObj->reset();
- string msg = pRecogObj->forward(buff, len, flag);
- pResult->msg += msg;
- nStep++;
- if (fnCallback)
- fnCallback(nStep, nTotal);
- }
-
- return pResult;
- }
-
- _RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
- {
- Model* pRecogObj = (Model*)handle;
- if (!pRecogObj)
- return nullptr;
-
- Audio audio(1);
- if(!audio.loadwav(szWavfile))
- return nullptr;
- //audio.split();
-
- float* buff;
- int len;
- int flag = 0;
- int nStep = 0;
- int nTotal = audio.get_queue_size();
- RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
- pResult->snippet_time = audio.get_time_len();
- while (audio.fetch(buff, len, flag) > 0) {
- pRecogObj->reset();
- string msg = pRecogObj->forward(buff, len, flag);
- pResult->msg+= msg;
- nStep++;
- if (fnCallback)
- fnCallback(nStep, nTotal);
- }
-
- return pResult;
- }
-
- _RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result)
- {
- if (!Result)
- return 0;
-
- return 1;
- }
-
-
- _RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result)
- {
- if (!Result)
- return 0.0f;
-
- return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
- }
-
- _RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
- {
- RPASR_RECOG_RESULT * pResult = (RPASR_RECOG_RESULT*)Result;
- if(!pResult)
- return nullptr;
-
- return pResult->msg.c_str();
- }
-
- _RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
- {
- if (Result)
- {
- delete (RPASR_RECOG_RESULT*)Result;
- }
- }
-
- _RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
- {
- Model* pRecogObj = (Model*)handle;
-
- if (!pRecogObj)
- return;
-
- delete pRecogObj;
- }
-
-#ifdef __cplusplus
-
-}
-#endif
-
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index a49069c..695e0f7 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
@@ -18,7 +18,10 @@
cmvn_path = pathAppend(path, "am.mvn");
config_path = pathAppend(path, "config.yaml");
- fe = new FeatureExtract(3);
+ fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+ fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+ memset(fft_input, 0, sizeof(float) * fft_size);
+ plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
//sessionOptions.SetInterOpNumThreads(1);
sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -26,20 +29,20 @@
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
- m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
- m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
string strName;
- getInputName(m_session, strName);
+ getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
- getInputName(m_session, strName,1);
+ getInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);
- getOutputName(m_session, strName);
+ getOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
- getOutputName(m_session, strName,1);
+ getOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@@ -52,20 +55,16 @@
ModelImp::~ModelImp()
{
- if(fe)
- delete fe;
- if (m_session)
- {
- delete m_session;
- m_session = nullptr;
- }
if(vocab)
delete vocab;
+ fftwf_free(fft_input);
+ fftwf_free(fft_out);
+ fftwf_destroy_plan(plan);
+ fftwf_cleanup();
}
void ModelImp::reset()
{
- fe->reset();
}
void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -159,13 +158,20 @@
string ModelImp::forward(float* din, int len, int flag)
{
-
Tensor<float>* in;
- fe->insert(din, len, flag);
+ FeatureExtract* fe = new FeatureExtract(3);
+ fe->reset();
+ fe->insert(plan, din, len, flag);
fe->fetch(in);
apply_lfr(in);
apply_cmvn(in);
Ort::RunOptions run_option;
+
+#ifdef _WIN_X86
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
@@ -192,7 +198,6 @@
auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
-
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -203,9 +208,14 @@
result = "";
}
-
- if(in)
+ if(in){
delete in;
+ in = nullptr;
+ }
+ if(fe){
+ delete fe;
+ fe = nullptr;
+ }
return result;
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index 395c328..8946ae1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -8,7 +8,10 @@
class ModelImp : public Model {
private:
- FeatureExtract* fe;
+ int fft_size=512;
+ float *fft_input;
+ fftwf_complex *fft_out;
+ fftwf_plan plan;
Vocab* vocab;
vector<float> means_list;
@@ -21,21 +24,13 @@
string greedy_search( float* in, int nLen);
-#ifdef _WIN_X86
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-#else
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-#endif
-
- Ort::Session* m_session = nullptr;
- Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
- Ort::SessionOptions sessionOptions = Ort::SessionOptions();
+ std::unique_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions sessionOptions;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
- //string m_strInputName, m_strInputNameLen;
- //string m_strOutputName, m_strOutputNameLen;
public:
ModelImp(const char* path, int nNumThread=0, bool quantize=false);
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index c9f43bf..3aeed14 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -44,9 +44,10 @@
#include "FeatureQueue.h"
#include "SpeechWrap.h"
#include <Audio.h>
+#include "resample.h"
#include "Model.h"
#include "paraformer_onnx.h"
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
using namespace paraformer;
diff --git a/funasr/runtime/onnxruntime/src/resample.cc b/funasr/runtime/onnxruntime/src/resample.cc
new file mode 100644
index 0000000..0238752
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/resample.cc
@@ -0,0 +1,305 @@
+/**
+ * Copyright 2013 Pegah Ghahremani
+ * 2014 IMSL, PKU-HKUST (author: Wei Shi)
+ * 2014 Yanqing Sun, Junjie Wang
+ * 2014 Johns Hopkins University (author: Daniel Povey)
+ * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// this file is copied and modified from
+// kaldi/src/feat/resample.cc
+
+#include "resample.h"
+
+#include <assert.h>
+#include <math.h>
+#include <stdio.h>
+
+#include <cstdlib>
+#include <type_traits>
+
+#ifndef M_2PI
+#define M_2PI 6.283185307179586476925286766559005
+#endif
+
+#ifndef M_PI
+#define M_PI 3.1415926535897932384626433832795
+#endif
+
+template <class I>
+I Gcd(I m, I n) {
+ // this function is copied from kaldi/src/base/kaldi-math.h
+ if (m == 0 || n == 0) {
+ if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
+ fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
+ exit(-1);
+ }
+ return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
+ // return absolute value of whichever is nonzero
+ }
+ // could use compile-time assertion
+ // but involves messing with complex template stuff.
+ static_assert(std::is_integral<I>::value, "");
+ while (1) {
+ m %= n;
+ if (m == 0) return (n > 0 ? n : -n);
+ n %= m;
+ if (n == 0) return (m > 0 ? m : -m);
+ }
+}
+
+/// Returns the least common multiple of two integers. Will
+/// crash unless the inputs are positive.
+template <class I>
+I Lcm(I m, I n) {
+ // This function is copied from kaldi/src/base/kaldi-math.h
+ assert(m > 0 && n > 0);
+ I gcd = Gcd(m, n);
+ return gcd * (m / gcd) * (n / gcd);
+}
+
+static float DotProduct(const float *a, const float *b, int32_t n) {
+ float sum = 0;
+ for (int32_t i = 0; i != n; ++i) {
+ sum += a[i] * b[i];
+ }
+ return sum;
+}
+
+LinearResample::LinearResample(int32_t samp_rate_in_hz,
+ int32_t samp_rate_out_hz, float filter_cutoff_hz,
+ int32_t num_zeros)
+ : samp_rate_in_(samp_rate_in_hz),
+ samp_rate_out_(samp_rate_out_hz),
+ filter_cutoff_(filter_cutoff_hz),
+ num_zeros_(num_zeros) {
+ assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 &&
+ filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz &&
+ filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0);
+
+ // base_freq is the frequency of the repeating unit, which is the gcd
+ // of the input frequencies.
+ int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_);
+ input_samples_in_unit_ = samp_rate_in_ / base_freq;
+ output_samples_in_unit_ = samp_rate_out_ / base_freq;
+
+ SetIndexesAndWeights();
+ Reset();
+}
+
+void LinearResample::SetIndexesAndWeights() {
+ first_index_.resize(output_samples_in_unit_);
+ weights_.resize(output_samples_in_unit_);
+
+ double window_width = num_zeros_ / (2.0 * filter_cutoff_);
+
+ for (int32_t i = 0; i < output_samples_in_unit_; i++) {
+ double output_t = i / static_cast<double>(samp_rate_out_);
+ double min_t = output_t - window_width, max_t = output_t + window_width;
+ // we do ceil on the min and floor on the max, because if we did it
+ // the other way around we would unnecessarily include indexes just
+ // outside the window, with zero coefficients. It's possible
+ // if the arguments to the ceil and floor expressions are integers
+ // (e.g. if filter_cutoff_ has an exact ratio with the sample rates),
+ // that we unnecessarily include something with a zero coefficient,
+ // but this is only a slight efficiency issue.
+ int32_t min_input_index = ceil(min_t * samp_rate_in_),
+ max_input_index = floor(max_t * samp_rate_in_),
+ num_indices = max_input_index - min_input_index + 1;
+ first_index_[i] = min_input_index;
+ weights_[i].resize(num_indices);
+ for (int32_t j = 0; j < num_indices; j++) {
+ int32_t input_index = min_input_index + j;
+ double input_t = input_index / static_cast<double>(samp_rate_in_),
+ delta_t = input_t - output_t;
+ // sign of delta_t doesn't matter.
+ weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_;
+ }
+ }
+}
+
+/** Here, t is a time in seconds representing an offset from
+ the center of the windowed filter function, and FilterFunction(t)
+ returns the windowed filter function, described
+ in the header as h(t) = f(t)g(t), evaluated at t.
+*/
+float LinearResample::FilterFunc(float t) const {
+ float window, // raised-cosine (Hanning) window of width
+ // num_zeros_/2*filter_cutoff_
+ filter; // sinc filter function
+ if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
+ window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
+ else
+ window = 0.0; // outside support of window function
+ if (t != 0)
+ filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t);
+ else
+ filter = 2 * filter_cutoff_; // limit of the function at t = 0
+ return filter * window;
+}
+
+void LinearResample::Reset() {
+ input_sample_offset_ = 0;
+ output_sample_offset_ = 0;
+ input_remainder_.resize(0);
+}
+
+void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
+ std::vector<float> *output) {
+ int64_t tot_input_samp = input_sample_offset_ + input_dim,
+ tot_output_samp = GetNumOutputSamples(tot_input_samp, flush);
+
+ assert(tot_output_samp >= output_sample_offset_);
+
+ output->resize(tot_output_samp - output_sample_offset_);
+
+ // samp_out is the index into the total output signal, not just the part
+ // of it we are producing here.
+ for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
+ samp_out++) {
+ int64_t first_samp_in;
+ int32_t samp_out_wrapped;
+ GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
+ const std::vector<float> &weights = weights_[samp_out_wrapped];
+ // first_input_index is the first index into "input" that we have a weight
+ // for.
+ int32_t first_input_index =
+ static_cast<int32_t>(first_samp_in - input_sample_offset_);
+ float this_output;
+ if (first_input_index >= 0 &&
+ first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
+ this_output =
+ DotProduct(input + first_input_index, weights.data(), weights.size());
+ } else { // Handle edge cases.
+ this_output = 0.0;
+ for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) {
+ float weight = weights[i];
+ int32_t input_index = first_input_index + i;
+ if (input_index < 0 &&
+ static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) {
+ this_output +=
+ weight * input_remainder_[input_remainder_.size() + input_index];
+ } else if (input_index >= 0 && input_index < input_dim) {
+ this_output += weight * input[input_index];
+ } else if (input_index >= input_dim) {
+ // We're past the end of the input and are adding zero; should only
+ // happen if the user specified flush == true, or else we would not
+ // be trying to output this sample.
+ assert(flush);
+ }
+ }
+ }
+ int32_t output_index =
+ static_cast<int32_t>(samp_out - output_sample_offset_);
+ (*output)[output_index] = this_output;
+ }
+
+ if (flush) {
+ Reset(); // Reset the internal state.
+ } else {
+ SetRemainder(input, input_dim);
+ input_sample_offset_ = tot_input_samp;
+ output_sample_offset_ = tot_output_samp;
+ }
+}
+
+int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
+ bool flush) const {
+ // For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
+ // where tick_freq is the least common multiple of samp_rate_in_ and
+ // samp_rate_out_.
+ int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_);
+ int32_t ticks_per_input_period = tick_freq / samp_rate_in_;
+
+ // work out the number of ticks in the time interval
+ // [ 0, input_num_samp/samp_rate_in_ ).
+ int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
+ if (!flush) {
+ float window_width = num_zeros_ / (2.0 * filter_cutoff_);
+ // To count the window-width in ticks we take the floor. This
+ // is because since we're looking for the largest integer num-out-samp
+ // that fits in the interval, which is open on the right, a reduction
+ // in interval length of less than a tick will never make a difference.
+ // For example, the largest integer in the interval [ 0, 2 ) and the
+ // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
+ // So when we're subtracting the window-width we can ignore the fractional
+ // part.
+ int32_t window_width_ticks = floor(window_width * tick_freq);
+ // The time-period of the output that we can sample gets reduced
+ // by the window-width (which is actually the distance from the
+ // center to the edge of the windowing function) if we're not
+ // "flushing the output".
+ interval_length_in_ticks -= window_width_ticks;
+ }
+ if (interval_length_in_ticks <= 0) return 0;
+
+ int32_t ticks_per_output_period = tick_freq / samp_rate_out_;
+ // Get the last output-sample in the closed interval, i.e. replacing [ ) with
+ // [ ]. Note: integer division rounds down. See
+ // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
+ // the notation.
+ int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
+ // We need the last output-sample in the open interval, so if it takes us to
+ // the end of the interval exactly, subtract one.
+ if (last_output_samp * ticks_per_output_period == interval_length_in_ticks)
+ last_output_samp--;
+
+ // First output-sample index is zero, so the number of output samples
+ // is the last output-sample plus one.
+ int64_t num_output_samp = last_output_samp + 1;
+ return num_output_samp;
+}
+
+// inline
+void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in,
+ int32_t *samp_out_wrapped) const {
+ // A unit is the smallest nonzero amount of time that is an exact
+ // multiple of the input and output sample periods. The unit index
+ // is the answer to "which numbered unit we are in".
+ int64_t unit_index = samp_out / output_samples_in_unit_;
+ // samp_out_wrapped is equal to samp_out % output_samples_in_unit_
+ *samp_out_wrapped =
+ static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_);
+ *first_samp_in =
+ first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_;
+}
+
+void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
+ std::vector<float> old_remainder(input_remainder_);
+ // max_remainder_needed is the width of the filter from side to side,
+ // measured in input samples. you might think it should be half that,
+ // but you have to consider that you might be wanting to output samples
+ // that are "in the past" relative to the beginning of the latest
+ // input... anyway, storing more remainder than needed is not harmful.
+ int32_t max_remainder_needed =
+ ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
+ input_remainder_.resize(max_remainder_needed);
+ for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
+ index < 0; index++) {
+ // we interpret "index" as an offset from the end of "input" and
+ // from the end of input_remainder_.
+ int32_t input_index = index + input_dim;
+ if (input_index >= 0) {
+ input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
+ input[input_index];
+ } else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) {
+ input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
+ old_remainder[input_index +
+ static_cast<int32_t>(old_remainder.size())];
+ // else leave it at zero.
+ }
+ }
+}
diff --git a/funasr/runtime/onnxruntime/src/resample.h b/funasr/runtime/onnxruntime/src/resample.h
new file mode 100644
index 0000000..b9a283a
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/resample.h
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2013 Pegah Ghahremani
+ * 2014 IMSL, PKU-HKUST (author: Wei Shi)
+ * 2014 Yanqing Sun, Junjie Wang
+ * 2014 Johns Hopkins University (author: Daniel Povey)
+ * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// this file is copied and modified from
+// kaldi/src/feat/resample.h
+
+#include <cstdint>
+#include <vector>
+
+
+/*
+ We require that the input and output sampling rate be specified as
+ integers, as this is an easy way to specify that their ratio be rational.
+*/
+
+class LinearResample {
+ public:
+ /// Constructor. We make the input and output sample rates integers, because
+ /// we are going to need to find a common divisor. This should just remind
+ /// you that they need to be integers. The filter cutoff needs to be less
+ /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros
+ /// controls the sharpness of the filter, more == sharper but less efficient.
+ /// We suggest around 4 to 10 for normal use.
+ LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz,
+ float filter_cutoff_hz, int32_t num_zeros);
+
+ /// Calling the function Reset() resets the state of the object prior to
+ /// processing a new signal; it is only necessary if you have called
+ /// Resample(x, x_size, false, y) for some signal, leading to a remainder of
+ /// the signal being called, but then abandon processing the signal before
+ /// calling Resample(x, x_size, true, y) for the last piece. Call it
+ /// unnecessarily between signals will not do any harm.
+ void Reset();
+
+ /// This function does the resampling. If you call it with flush == true and
+ /// you have never called it with flush == false, it just resamples the input
+ /// signal (it resizes the output to a suitable number of samples).
+ ///
+ /// You can also use this function to process a signal a piece at a time.
+ /// suppose you break it into piece1, piece2, ... pieceN. You can call
+ /// \code{.cc}
+ /// Resample(piece1, piece1_size, false, &output1);
+ /// Resample(piece2, piece2_size, false, &output2);
+ /// Resample(piece3, piece3_size, true, &output3);
+ /// \endcode
+ /// If you call it with flush == false, it won't output the last few samples
+ /// but will remember them, so that if you later give it a second piece of
+ /// the input signal it can process it correctly.
+ /// If your most recent call to the object was with flush == false, it will
+ /// have internal state; you can remove this by calling Reset().
+ /// Empty input is acceptable.
+ void Resample(const float *input, int32_t input_dim, bool flush,
+ std::vector<float> *output);
+
+ //// Return the input and output sampling rates (for checks, for example)
+ int32_t GetInputSamplingRate() const { return samp_rate_in_; }
+ int32_t GetOutputSamplingRate() const { return samp_rate_out_; }
+
+ private:
+ void SetIndexesAndWeights();
+
+ float FilterFunc(float) const;
+
+ /// This function outputs the number of output samples we will output
+ /// for a signal with "input_num_samp" input samples. If flush == true,
+ /// we return the largest n such that
+ /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ),
+ /// and note that the interval is half-open. If flush == false,
+ /// define window_width as num_zeros / (2.0 * filter_cutoff_);
+ /// we return the largest n such that (n/samp_rate_out_) is in the interval
+ /// [ 0, input_num_samp/samp_rate_in_ - window_width ).
+ int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const;
+
+ /// Given an output-sample index, this function outputs to *first_samp_in the
+ /// first input-sample index that we have a weight on (may be negative),
+ /// and to *samp_out_wrapped the index into weights_ where we can get the
+ /// corresponding weights on the input.
+ inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in,
+ int32_t *samp_out_wrapped) const;
+
+ void SetRemainder(const float *input, int32_t input_dim);
+
+ private:
+ // The following variables are provided by the user.
+ int32_t samp_rate_in_;
+ int32_t samp_rate_out_;
+ float filter_cutoff_;
+ int32_t num_zeros_;
+
+ int32_t input_samples_in_unit_; ///< The number of input samples in the
+ ///< smallest repeating unit: num_samp_in_ =
+ ///< samp_rate_in_hz / Gcd(samp_rate_in_hz,
+ ///< samp_rate_out_hz)
+
+ int32_t output_samples_in_unit_; ///< The number of output samples in the
+ ///< smallest repeating unit: num_samp_out_
+ ///< = samp_rate_out_hz /
+ ///< Gcd(samp_rate_in_hz, samp_rate_out_hz)
+
+ /// The first input-sample index that we sum over, for this output-sample
+ /// index. May be negative; any truncation at the beginning is handled
+ /// separately. This is just for the first few output samples, but we can
+ /// extrapolate the correct input-sample index for arbitrary output samples.
+ std::vector<int32_t> first_index_;
+
+ /// Weights on the input samples, for this output-sample index.
+ std::vector<std::vector<float>> weights_;
+
+ // the following variables keep track of where we are in a particular signal,
+ // if it is being provided over multiple calls to Resample().
+
+ int64_t input_sample_offset_; ///< The number of input samples we have
+ ///< already received for this signal
+ ///< (including anything in remainder_)
+ int64_t output_sample_offset_; ///< The number of samples we have already
+ ///< output for this signal.
+ std::vector<float> input_remainder_; ///< A small trailing part of the
+ ///< previously seen input signal.
+};
diff --git a/funasr/runtime/onnxruntime/tester/CMakeLists.txt b/funasr/runtime/onnxruntime/tester/CMakeLists.txt
index f66319d..e3224e3 100644
--- a/funasr/runtime/onnxruntime/tester/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/tester/CMakeLists.txt
@@ -8,7 +8,7 @@
endif()
endif()
-set(EXTRA_LIBS rapidasr)
+set(EXTRA_LIBS funasr)
include_directories(${CMAKE_SOURCE_DIR}/include)
diff --git a/funasr/runtime/onnxruntime/tester/tester.cpp b/funasr/runtime/onnxruntime/tester/tester.cpp
index 35d534f..7257603 100644
--- a/funasr/runtime/onnxruntime/tester/tester.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester.cpp
@@ -5,7 +5,7 @@
#include <win_func.h>
#endif
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@@ -26,7 +26,7 @@
// is quantize
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
- RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+ FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde)
{
@@ -42,62 +42,22 @@
gettimeofday(&start, NULL);
float snippet_time = 0.0f;
- RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
gettimeofday(&end, NULL);
if (Result)
{
- string msg = RapidAsrGetResult(Result, 0);
+ string msg = FunASRGetResult(Result, 0);
setbuf(stdout, NULL);
- cout << "Result: \"";
- cout << msg << "\"." << endl;
- snippet_time = RapidAsrGetRetSnippetTime(Result);
- RapidAsrFreeResult(Result);
+ printf("Result: %s \n", msg.c_str());
+ snippet_time = FunASRGetRetSnippetTime(Result);
+ FunASRFreeResult(Result);
}
else
{
cout <<"no return data!";
}
-
- //char* buff = nullptr;
- //int len = 0;
- //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
- //if (ifs.is_open())
- //{
- // ifs.seekg(0, std::ios::end);
- // len = ifs.tellg();
- // ifs.seekg(0, std::ios::beg);
-
- // buff = new char[len];
-
- // ifs.read(buff, len);
-
-
- // //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-
- // RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
- // //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
- // gettimeofday(&end, NULL);
- //
- // if (Result)
- // {
- // string msg = RapidAsrGetResult(Result, 0);
- // setbuf(stdout, NULL);
- // cout << "Result: \"";
- // cout << msg << endl;
- // cout << "\"." << endl;
- // snippet_time = RapidAsrGetRetSnippetTime(Result);
- // RapidAsrFreeResult(Result);
- // }
- // else
- // {
- // cout <<"no return data!";
- // }
-
- //
- //delete[]buff;
- //}
printf("Audio length %lfs.\n", (double)snippet_time);
seconds = (end.tv_sec - start.tv_sec);
@@ -105,9 +65,9 @@
printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
- RapidAsrUninit(AsrHanlde);
+ FunASRUninit(AsrHanlde);
return 0;
}
-
\ No newline at end of file
+
diff --git a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
index 9651900..dd79887 100644
--- a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -5,7 +5,7 @@
#include <win_func.h>
#endif
-#include "librapidasrapi.h"
+#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@@ -47,7 +47,7 @@
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
- RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+ FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@@ -61,7 +61,7 @@
// warm up
for (size_t i = 0; i < 30; i++)
{
- RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
}
// forward
@@ -72,19 +72,19 @@
for (size_t i = 0; i < wav_list.size(); i++)
{
gettimeofday(&start, NULL);
- RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
total_time += taking_micros;
if(Result){
- string msg = RapidAsrGetResult(Result, 0);
- printf("Result: %s \n", msg);
+ string msg = FunASRGetResult(Result, 0);
+ printf("Result: %s \n", msg.c_str());
- snippet_time = RapidAsrGetRetSnippetTime(Result);
+ snippet_time = FunASRGetRetSnippetTime(Result);
total_length += snippet_time;
- RapidAsrFreeResult(Result);
+ FunASRFreeResult(Result);
}else{
cout <<"No return data!";
}
@@ -94,6 +94,6 @@
printf("total_time_comput %ld ms.\n", total_time / 1000);
printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
- RapidAsrUninit(AsrHanlde);
+ FunASRUninit(AsrHanlde);
return 0;
}
diff --git a/funasr/runtime/python/grpc/grpc_main_client.py b/funasr/runtime/python/grpc/grpc_main_client.py
new file mode 100644
index 0000000..b6491df
--- /dev/null
+++ b/funasr/runtime/python/grpc/grpc_main_client.py
@@ -0,0 +1,62 @@
+import grpc
+import json
+import time
+import asyncio
+import soundfile as sf
+import argparse
+
+from grpc_client import transcribe_audio_bytes
+from paraformer_pb2_grpc import ASRStub
+
+# send the audio data once
+async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
+ with grpc.insecure_channel(grpc_uri) as channel:
+ stub = ASRStub(channel)
+ for line in wav_scp:
+ wav_file = line.split()[1]
+ wav, _ = sf.read(wav_file, dtype='int16')
+
+ b = time.time()
+ response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
+ resp = response.next()
+ text = ''
+ if 'decoding' == resp.action:
+ resp = response.next()
+ if 'finish' == resp.action:
+ text = json.loads(resp.sentence)['text']
+ response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
+ res= {'text': text, 'time': time.time() - b}
+ print(res)
+
+async def test(args):
+ wav_scp = open(args.wav_scp, "r").readlines()
+ uri = '{}:{}'.format(args.host, args.port)
+ res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host",
+ type=str,
+ default="127.0.0.1",
+ required=False,
+ help="grpc server host ip")
+ parser.add_argument("--port",
+ type=int,
+ default=10108,
+ required=False,
+ help="grpc server port")
+ parser.add_argument("--user_allowed",
+ type=str,
+ default="project1_user1",
+ help="allowed user for grpc client")
+ parser.add_argument("--sample_rate",
+ type=int,
+ default=16000,
+ help="audio sample_rate from client")
+ parser.add_argument("--wav_scp",
+ type=str,
+ required=True,
+ help="audio wav scp")
+ args = parser.parse_args()
+
+ asyncio.run(test(args))
diff --git a/funasr/runtime/python/grpc/grpc_server.py b/funasr/runtime/python/grpc/grpc_server.py
index d0be6f0..4fd4f95 100644
--- a/funasr/runtime/python/grpc/grpc_server.py
+++ b/funasr/runtime/python/grpc/grpc_server.py
@@ -109,7 +109,7 @@
else:
asr_result = ""
elif self.backend == "onnxruntime":
- from rapid_paraformer.utils.frontend import load_bytes
+ from funasr_onnx.utils.frontend import load_bytes
array = load_bytes(tmp_data)
asr_result = self.inference_16k_pipeline(array)[0]
end_time = int(round(time.time() * 1000))
diff --git a/funasr/runtime/python/libtorch/README.md b/funasr/runtime/python/libtorch/README.md
index aeb2eae..27b5f86 100644
--- a/funasr/runtime/python/libtorch/README.md
+++ b/funasr/runtime/python/libtorch/README.md
@@ -2,8 +2,6 @@
[FunASR](https://github.com/alibaba-damo-academy/FunASR) hopes to build a bridge between academic research and industrial applications on speech recognition. By supporting the training & finetuning of the industrial-grade speech recognition model released on ModelScope, researchers and developers can conduct research and production of speech recognition models more conveniently, and promote the development of speech recognition ecology. ASR for Fun锛�
-### Introduction
-- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
### Steps:
1. Export the model.
@@ -25,15 +23,20 @@
install from pip
```shell
- pip install --upgrade funasr_torch -i https://pypi.Python.org/simple
+ pip install -U funasr_torch
+ # For the users in China, you could install with the command:
+ # pip install -U funasr_torch -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
```
or install from source code
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/libtorch
- python setup.py build
- python setup.py install
+ pip install -e ./
+ # For the users in China, you could install with the command:
+ # pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
```
3. Run the demo.
diff --git a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index e169087..9954daa 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -23,7 +23,6 @@
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
- pred_bias: int = 1,
quantize: bool = False,
intra_op_num_threads: int = 1,
):
@@ -48,7 +47,10 @@
self.batch_size = batch_size
self.device_id = device_id
self.plot_timestamp_to = plot_timestamp_to
- self.pred_bias = pred_bias
+ if "predictor_bias" in config['model_conf'].keys():
+ self.pred_bias = config['model_conf']['predictor_bias']
+ else:
+ self.pred_bias = 0
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
index 2f09de8..86e78bc 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -23,9 +23,11 @@
):
check_argument_types()
- # self.token_list = self.load_token(token_path)
self.token_list = token_list
self.unk_symbol = token_list[-1]
+ self.token2id = {v: i for i, v in enumerate(self.token_list)}
+ self.unk_id = self.token2id[self.unk_symbol]
+
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
@@ -38,13 +40,8 @@
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
- token2id = {v: i for i, v in enumerate(self.token_list)}
- if self.unk_symbol not in token2id:
- raise TokenIDConverterError(
- f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
- )
- unk_id = token2id[self.unk_symbol]
- return [token2id.get(i, unk_id) for i in tokens]
+
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
@@ -134,7 +131,7 @@
@functools.lru_cache()
-def get_logger(name='torch_paraformer'):
+def get_logger(name='funasr_torch'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
diff --git a/funasr/runtime/python/libtorch/setup.py b/funasr/runtime/python/libtorch/setup.py
index 1560950..fd8b151 100644
--- a/funasr/runtime/python/libtorch/setup.py
+++ b/funasr/runtime/python/libtorch/setup.py
@@ -15,10 +15,10 @@
setuptools.setup(
name='funasr_torch',
- version='0.0.3',
+ version='0.0.4',
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license="The MIT License",
@@ -31,7 +31,7 @@
"PyYAML>=5.1.2", "torch-quant >= 0.4.0"],
packages=find_packages(include=["torch_paraformer*"]),
keywords=[
- 'funasr,paraformer, funasr_torch'
+ 'funasr, paraformer, funasr_torch'
],
classifiers=[
'Programming Language :: Python :: 3.6',
diff --git a/funasr/runtime/python/onnxruntime/README.md b/funasr/runtime/python/onnxruntime/README.md
index e19e3a2..87510fa 100644
--- a/funasr/runtime/python/onnxruntime/README.md
+++ b/funasr/runtime/python/onnxruntime/README.md
@@ -1,10 +1,6 @@
## Using funasr with ONNXRuntime
-### Introduction
-- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
-
-
### Steps:
1. Export the model.
- Command: (`Tips`: torch >= 1.11.0 is required.)
@@ -25,7 +21,10 @@
install from pip
```shell
-pip install --upgrade funasr_onnx -i https://pypi.Python.org/simple
+pip install -U funasr_onnx
+# For the users in China, you could install with the command:
+# pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
```
or install from source code
@@ -33,8 +32,10 @@
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/onnxruntime
-python setup.py build
-python setup.py install
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
```
3. Run the demo.
diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py
index f0f39d7..8fc82f1 100644
--- a/funasr/runtime/python/onnxruntime/demo.py
+++ b/funasr/runtime/python/onnxruntime/demo.py
@@ -1,5 +1,6 @@
from funasr_onnx import Paraformer
+
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0) # cpu
@@ -7,7 +8,6 @@
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
-
wav_path = "YourPath/xx.wav"
diff --git a/funasr/runtime/python/onnxruntime/demo_punc_offline.py b/funasr/runtime/python/onnxruntime/demo_punc_offline.py
new file mode 100644
index 0000000..469adda
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_punc_offline.py
@@ -0,0 +1,8 @@
+from funasr_onnx import CT_Transformer
+
+model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model = CT_Transformer(model_dir)
+
+text_in="璺ㄥ娌虫祦鏄吇鑲叉部宀镐汉姘戠殑鐢熷懡涔嬫簮闀挎湡浠ユ潵涓哄府鍔╀笅娓稿湴鍖洪槻鐏惧噺鐏句腑鏂规妧鏈汉鍛樺湪涓婃父鍦板尯鏋佷负鎭跺姡鐨勮嚜鐒舵潯浠朵笅鍏嬫湇宸ㄥぇ鍥伴毦鐢氳嚦鍐掔潃鐢熷懡鍗遍櫓鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒跺嚒鏄腑鏂硅兘鍋氱殑鎴戜滑閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛﹁鍒掑拰璁鸿瘉鍏奸【涓婁笅娓哥殑鍒╃泭"
+result = model(text_in)
+print(result[0])
diff --git a/funasr/runtime/python/onnxruntime/demo_punc_online.py b/funasr/runtime/python/onnxruntime/demo_punc_online.py
new file mode 100644
index 0000000..63f2f5e
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_punc_online.py
@@ -0,0 +1,15 @@
+from funasr_onnx import CT_Transformer_VadRealtime
+
+model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model = CT_Transformer_VadRealtime(model_dir)
+
+text_in = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦>闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+
+vads = text_in.split("|")
+rec_result_all=""
+param_dict = {"cache": []}
+for vad in vads:
+ result = model(vad, param_dict=param_dict)
+ rec_result_all += result[0]
+
+print(rec_result_all)
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_offline.py b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
new file mode 100644
index 0000000..ea76ec3
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
@@ -0,0 +1,11 @@
+import soundfile
+from funasr_onnx import Fsmn_vad
+
+
+model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
+model = Fsmn_vad(model_dir)
+
+#offline vad
+result = model(wav_path)
+print(result)
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_online.py b/funasr/runtime/python/onnxruntime/demo_vad_online.py
new file mode 100644
index 0000000..1ab4d9d
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_vad_online.py
@@ -0,0 +1,28 @@
+import soundfile
+from funasr_onnx import Fsmn_vad_online
+
+
+model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
+model = Fsmn_vad_online(model_dir)
+
+
+##online vad
+speech, sample_rate = soundfile.read(wav_path)
+speech_length = speech.shape[0]
+#
+sample_offset = 0
+step = 1600
+param_dict = {'in_cache': []}
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+ if sample_offset + step >= speech_length - 1:
+ step = speech_length - sample_offset
+ is_final = True
+ else:
+ is_final = False
+ param_dict['is_final'] = is_final
+ segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
+ param_dict=param_dict)
+ if segments_result:
+ print(segments_result)
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 647f9fa..7d8d662 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,2 +1,6 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
+from .vad_bin import Fsmn_vad
+from .vad_bin import Fsmn_vad_online
+from .punc_bin import CT_Transformer
+from .punc_bin import CT_Transformer_VadRealtime
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index cbdb8d9..e3ef8c7 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -23,7 +23,6 @@
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
- pred_bias: int = 1,
quantize: bool = False,
intra_op_num_threads: int = 4,
):
@@ -47,7 +46,10 @@
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
- self.pred_bias = pred_bias
+ if "predictor_bias" in config['model_conf'].keys():
+ self.pred_bias = config['model_conf']['predictor_bias']
+ else:
+ self.pred_bias = 0
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
new file mode 100644
index 0000000..bbbb913
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -0,0 +1,259 @@
+# -*- encoding: utf-8 -*-
+
+import os.path
+from pathlib import Path
+from typing import List, Union, Tuple
+import numpy as np
+
+from .utils.utils import (ONNXRuntimeError,
+ OrtInferSession, get_logger,
+ read_yaml)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+logging = get_logger()
+
+
+class CT_Transformer():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
+
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ config_file = os.path.join(model_dir, 'punc.yaml')
+ config = read_yaml(config_file)
+
+ self.converter = TokenIDConverter(config['token_list'])
+ self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.batch_size = 1
+ self.punc_list = config['punc_list']
+ self.period = 0
+ for i in range(len(self.punc_list)):
+ if self.punc_list[i] == ",":
+ self.punc_list[i] = "锛�"
+ elif self.punc_list[i] == "?":
+ self.punc_list[i] = "锛�"
+ elif self.punc_list[i] == "銆�":
+ self.period = i
+
+ def __call__(self, text: Union[list, str], split_size=20):
+ split_text = code_mix_split_words(text)
+ split_text_id = self.converter.tokens2ids(split_text)
+ mini_sentences = split_to_mini_sentence(split_text, split_size)
+ mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+ assert len(mini_sentences) == len(mini_sentences_id)
+ cache_sent = []
+ cache_sent_id = []
+ new_mini_sentence = ""
+ new_mini_sentence_punc = []
+ cache_pop_trigger_limit = 200
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
+ data = {
+ "text": mini_sentence_id[None,:],
+ "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
+ }
+ try:
+ outputs = self.infer(data['text'], data['text_lengths'])
+ y = outputs[0]
+ punctuations = np.argmax(y,axis=-1)[0]
+ assert punctuations.size == len(mini_sentence)
+ except ONNXRuntimeError:
+ logging.warning("error")
+
+ # Search for the last Period/QuestionMark as cache
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ # The sentence it too long, cut off at a comma.
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.period
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+
+ new_mini_sentence_punc += [int(x) for x in punctuations]
+ words_with_punc = []
+ for i in range(len(mini_sentence)):
+ if i > 0:
+ if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
+ words_with_punc.append(mini_sentence[i])
+ if self.punc_list[punctuations[i]] != "_":
+ words_with_punc.append(self.punc_list[punctuations[i]])
+ new_mini_sentence += "".join(words_with_punc)
+ # Add Period for the end of the sentence
+ new_mini_sentence_out = new_mini_sentence
+ new_mini_sentence_punc_out = new_mini_sentence_punc
+ if mini_sentence_i == len(mini_sentences) - 1:
+ if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+ new_mini_sentence_out = new_mini_sentence + "銆�"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ return new_mini_sentence_out, new_mini_sentence_punc_out
+
+ def infer(self, feats: np.ndarray,
+ feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len])
+ return outputs
+
+
+class CT_Transformer_VadRealtime(CT_Transformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4
+ ):
+ super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+
+ def __call__(self, text: str, param_dict: map, split_size=20):
+ cache_key = "cache"
+ assert cache_key in param_dict
+ cache = param_dict[cache_key]
+ if cache is not None and len(cache) > 0:
+ precache = "".join(cache)
+ else:
+ precache = ""
+ cache = []
+ full_text = precache + text
+ split_text = code_mix_split_words(full_text)
+ split_text_id = self.converter.tokens2ids(split_text)
+ mini_sentences = split_to_mini_sentence(split_text, split_size)
+ mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+ new_mini_sentence_punc = []
+ assert len(mini_sentences) == len(mini_sentences_id)
+
+ cache_sent = []
+ cache_sent_id = np.array([], dtype='int32')
+ sentence_punc_list = []
+ sentence_words_list = []
+ cache_pop_trigger_limit = 200
+ skip_num = 0
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ text_length = len(mini_sentence_id)
+ data = {
+ "input": mini_sentence_id[None,:],
+ "text_lengths": np.array([text_length], dtype='int32'),
+ "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
+ "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ }
+ try:
+ outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+ y = outputs[0]
+ punctuations = np.argmax(y,axis=-1)[0]
+ assert punctuations.size == len(mini_sentence)
+ except ONNXRuntimeError:
+ logging.warning("error")
+
+ # Search for the last Period/QuestionMark as cache
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ # The sentence it too long, cut off at a comma.
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.period
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+
+ punctuations_np = [int(x) for x in punctuations]
+ new_mini_sentence_punc += punctuations_np
+ sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+ sentence_words_list += mini_sentence
+
+ assert len(sentence_punc_list) == len(sentence_words_list)
+ words_with_punc = []
+ sentence_punc_list_out = []
+ for i in range(0, len(sentence_words_list)):
+ if i > 0:
+ if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+ sentence_words_list[i] = " " + sentence_words_list[i]
+ if skip_num < len(cache):
+ skip_num += 1
+ else:
+ words_with_punc.append(sentence_words_list[i])
+ if skip_num >= len(cache):
+ sentence_punc_list_out.append(sentence_punc_list[i])
+ if sentence_punc_list[i] != "_":
+ words_with_punc.append(sentence_punc_list[i])
+ sentence_out = "".join(words_with_punc)
+
+ sentenceEnd = -1
+ for i in range(len(sentence_punc_list) - 2, 1, -1):
+ if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
+ sentenceEnd = i
+ break
+ cache_out = sentence_words_list[sentenceEnd + 1:]
+ if sentence_out[-1] in self.punc_list:
+ sentence_out = sentence_out[:-1]
+ sentence_punc_list_out[-1] = "_"
+ param_dict[cache_key] = cache_out
+ return sentence_out, sentence_punc_list_out, cache_out
+
+ def vad_mask(self, size, vad_pos, dtype=np.bool):
+ """Create mask for decoder self-attention.
+
+ :param int size: size of mask
+ :param int vad_pos: index of vad index
+ :param torch.dtype dtype: result dtype
+ :rtype: torch.Tensor (B, Lmax, Lmax)
+ """
+ ret = np.ones((size, size), dtype=dtype)
+ if vad_pos <= 0 or vad_pos >= size:
+ return ret
+ sub_corner = np.zeros(
+ (vad_pos - 1, size - vad_pos), dtype=dtype)
+ ret[0:vad_pos - 1, vad_pos:] = sub_corner
+ return ret
+
+ def infer(self, feats: np.ndarray,
+ feats_len: np.ndarray,
+ vad_mask: np.ndarray,
+ sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+ return outputs
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
new file mode 100644
index 0000000..029f529
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -0,0 +1,616 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+
+import math
+import numpy as np
+
+class VadStateMachine(Enum):
+ kVadInStateStartPointNotDetected = 1
+ kVadInStateInSpeechSegment = 2
+ kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+ kFrameStateInvalid = -1
+ kFrameStateSpeech = 1
+ kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+ kChangeStateSpeech2Speech = 0
+ kChangeStateSpeech2Sil = 1
+ kChangeStateSil2Sil = 2
+ kChangeStateSil2Speech = 3
+ kChangeStateNoBegin = 4
+ kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+ kVadSingleUtteranceDetectMode = 0
+ kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+ def __init__(
+ self,
+ sample_rate: int = 16000,
+ detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+ snr_mode: int = 0,
+ max_end_silence_time: int = 800,
+ max_start_silence_time: int = 3000,
+ do_start_point_detection: bool = True,
+ do_end_point_detection: bool = True,
+ window_size_ms: int = 200,
+ sil_to_speech_time_thres: int = 150,
+ speech_to_sil_time_thres: int = 150,
+ speech_2_noise_ratio: float = 1.0,
+ do_extend: int = 1,
+ lookback_time_start_point: int = 200,
+ lookahead_time_end_point: int = 100,
+ max_single_segment_time: int = 60000,
+ nn_eval_block_size: int = 8,
+ dcd_block_size: int = 4,
+ snr_thres: int = -100.0,
+ noise_frame_num_used_for_snr: int = 100,
+ decibel_thres: int = -100.0,
+ speech_noise_thres: float = 0.6,
+ fe_prior_thres: float = 1e-4,
+ silence_pdf_num: int = 1,
+ sil_pdf_ids: List[int] = [0],
+ speech_noise_thresh_low: float = -0.1,
+ speech_noise_thresh_high: float = 0.3,
+ output_frame_probs: bool = False,
+ frame_in_ms: int = 10,
+ frame_length_ms: int = 25,
+ ):
+ self.sample_rate = sample_rate
+ self.detect_mode = detect_mode
+ self.snr_mode = snr_mode
+ self.max_end_silence_time = max_end_silence_time
+ self.max_start_silence_time = max_start_silence_time
+ self.do_start_point_detection = do_start_point_detection
+ self.do_end_point_detection = do_end_point_detection
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
+ self.speech_2_noise_ratio = speech_2_noise_ratio
+ self.do_extend = do_extend
+ self.lookback_time_start_point = lookback_time_start_point
+ self.lookahead_time_end_point = lookahead_time_end_point
+ self.max_single_segment_time = max_single_segment_time
+ self.nn_eval_block_size = nn_eval_block_size
+ self.dcd_block_size = dcd_block_size
+ self.snr_thres = snr_thres
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+ self.decibel_thres = decibel_thres
+ self.speech_noise_thres = speech_noise_thres
+ self.fe_prior_thres = fe_prior_thres
+ self.silence_pdf_num = silence_pdf_num
+ self.sil_pdf_ids = sil_pdf_ids
+ self.speech_noise_thresh_low = speech_noise_thresh_low
+ self.speech_noise_thresh_high = speech_noise_thresh_high
+ self.output_frame_probs = output_frame_probs
+ self.frame_in_ms = frame_in_ms
+ self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+ def __init__(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+ def Reset(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+
+class E2EVadFrameProb(object):
+ def __init__(self):
+ self.noise_prob = 0.0
+ self.speech_prob = 0.0
+ self.score = 0.0
+ self.frame_id = 0
+ self.frm_state = 0
+
+
+class WindowDetector(object):
+ def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+ speech_to_sil_time: int, frame_size_ms: int):
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time = sil_to_speech_time
+ self.speech_to_sil_time = speech_to_sil_time
+ self.frame_size_ms = frame_size_ms
+
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
+ self.win_sum = 0
+ self.win_state = [0] * self.win_size_frame # 鍒濆鍖栫獥
+
+ self.cur_win_pos = 0
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def Reset(self) -> None:
+ self.cur_win_pos = 0
+ self.win_sum = 0
+ self.win_state = [0] * self.win_size_frame
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def GetWinSize(self) -> int:
+ return int(self.win_size_frame)
+
+ def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+ cur_frame_state = FrameState.kFrameStateSil
+ if frameState == FrameState.kFrameStateSpeech:
+ cur_frame_state = 1
+ elif frameState == FrameState.kFrameStateSil:
+ cur_frame_state = 0
+ else:
+ return AudioChangeState.kChangeStateInvalid
+ self.win_sum -= self.win_state[self.cur_win_pos]
+ self.win_sum += cur_frame_state
+ self.win_state[self.cur_win_pos] = cur_frame_state
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+ if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSpeech
+ return AudioChangeState.kChangeStateSil2Speech
+
+ if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSil
+ return AudioChangeState.kChangeStateSpeech2Sil
+
+ if self.pre_frame_state == FrameState.kFrameStateSil:
+ return AudioChangeState.kChangeStateSil2Sil
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
+ return AudioChangeState.kChangeStateSpeech2Speech
+ return AudioChangeState.kChangeStateInvalid
+
+ def FrameSizeMs(self) -> int:
+ return int(self.frame_size_ms)
+
+
+class E2EVadModel():
+ def __init__(self, vad_post_args: Dict[str, Any]):
+ super(E2EVadModel, self).__init__()
+ self.vad_opts = VADXOptions(**vad_post_args)
+ self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+ self.vad_opts.sil_to_speech_time_thres,
+ self.vad_opts.speech_to_sil_time_thres,
+ self.vad_opts.frame_in_ms)
+ # self.encoder = encoder
+ # init variables
+ self.is_final = False
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.ResetDetection()
+
+ def AllResetDetection(self):
+ self.is_final = False
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.ResetDetection()
+
+ def ResetDetection(self):
+ self.continous_silence_frame_count = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.windows_detector.Reset()
+ self.sil_frame = 0
+ self.frame_probs = []
+
+ def ComputeDecibel(self) -> None:
+ frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+ frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if self.data_buf_all is None:
+ self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
+ self.data_buf = self.data_buf_all
+ else:
+ self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
+ for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+ self.decibel.append(
+ 10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
+ 0.000001))
+
+ def ComputeScores(self, scores: np.ndarray) -> None:
+ # scores = self.encoder(feats, in_cache) # return B * T * D
+ self.vad_opts.nn_eval_block_size = scores.shape[1]
+ self.frm_cnt += scores.shape[1] # count total frames
+ if self.scores is None:
+ self.scores = scores # the first calculation
+ else:
+ self.scores = np.concatenate((self.scores, scores), axis=1)
+
+ def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
+ while self.data_buf_start_frame < frame_idx:
+ if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+ self.data_buf_start_frame += 1
+ self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+ def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+ last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+ self.PopDataBufTillFrame(start_frm)
+ expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+ if last_frm_is_end_point:
+ extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+ self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+ expected_sample_number += int(extra_sample)
+ if end_point_is_sent_end:
+ expected_sample_number = max(expected_sample_number, len(self.data_buf))
+ if len(self.data_buf) < expected_sample_number:
+ print('error in calling pop data_buf\n')
+
+ if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+ self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+ self.output_data_buf[-1].Reset()
+ self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+ self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+ self.output_data_buf[-1].doa = 0
+ cur_seg = self.output_data_buf[-1]
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('warning\n')
+ out_pos = len(cur_seg.buffer) # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
+ data_to_pop = 0
+ if end_point_is_sent_end:
+ data_to_pop = expected_sample_number
+ else:
+ data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if data_to_pop > len(self.data_buf):
+ print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+ data_to_pop = len(self.data_buf)
+ expected_sample_number = len(self.data_buf)
+
+ cur_seg.doa = 0
+ for sample_cpy_out in range(0, data_to_pop):
+ # cur_seg.buffer[out_pos ++] = data_buf_.back();
+ out_pos += 1
+ for sample_cpy_out in range(data_to_pop, expected_sample_number):
+ # cur_seg.buffer[out_pos++] = data_buf_.back()
+ out_pos += 1
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('Something wrong with the VAD algorithm\n')
+ self.data_buf_start_frame += frm_cnt
+ cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+ if first_frm_is_start_point:
+ cur_seg.contain_seg_start_point = True
+ if last_frm_is_end_point:
+ cur_seg.contain_seg_end_point = True
+
+ def OnSilenceDetected(self, valid_frame: int):
+ self.lastest_confirmed_silence_frame = valid_frame
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ self.PopDataBufTillFrame(valid_frame)
+ # silence_detected_callback_
+ # pass
+
+ def OnVoiceDetected(self, valid_frame: int) -> None:
+ self.latest_confirmed_speech_frame = valid_frame
+ self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
+ def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+ if self.vad_opts.do_start_point_detection:
+ pass
+ if self.confirmed_start_frame != -1:
+ print('not reset vad properly\n')
+ else:
+ self.confirmed_start_frame = start_frame
+
+ if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
+ def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
+ for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
+ self.OnVoiceDetected(t)
+ if self.vad_opts.do_end_point_detection:
+ pass
+ if self.confirmed_end_frame != -1:
+ print('not reset vad properly\n')
+ else:
+ self.confirmed_end_frame = end_frame
+ if not fake_result:
+ self.sil_frame = 0
+ self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
+ self.number_end_time_detected += 1
+
+ def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+ if is_final_frame:
+ self.OnVoiceEnd(cur_frm_idx, False, True)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+ def GetLatency(self) -> int:
+ return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+ def LatencyFrmNumAtStartPoint(self) -> int:
+ vad_latency = self.windows_detector.GetWinSize()
+ if self.vad_opts.do_extend:
+ vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+ return vad_latency
+
+ def GetFrameState(self, t: int) -> FrameState:
+ frame_state = FrameState.kFrameStateInvalid
+ cur_decibel = self.decibel[t]
+ cur_snr = cur_decibel - self.noise_average_decibel
+ # for each frame, calc log posterior probability of each state
+ if cur_decibel < self.vad_opts.decibel_thres:
+ frame_state = FrameState.kFrameStateSil
+ self.DetectOneFrame(frame_state, t, False)
+ return frame_state
+
+ sum_score = 0.0
+ noise_prob = 0.0
+ assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
+ if len(self.sil_pdf_ids) > 0:
+ assert len(self.scores) == 1 # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
+ sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+ sum_score = sum(sil_pdf_scores)
+ noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+ total_score = 1.0
+ sum_score = total_score - sum_score
+ speech_prob = math.log(sum_score)
+ if self.vad_opts.output_frame_probs:
+ frame_prob = E2EVadFrameProb()
+ frame_prob.noise_prob = noise_prob
+ frame_prob.speech_prob = speech_prob
+ frame_prob.score = sum_score
+ frame_prob.frame_id = t
+ self.frame_probs.append(frame_prob)
+ if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+ if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+ frame_state = FrameState.kFrameStateSpeech
+ else:
+ frame_state = FrameState.kFrameStateSil
+ else:
+ frame_state = FrameState.kFrameStateSil
+ if self.noise_average_decibel < -99.9:
+ self.noise_average_decibel = cur_decibel
+ else:
+ self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+ self.vad_opts.noise_frame_num_used_for_snr
+ - 1)) / self.vad_opts.noise_frame_num_used_for_snr
+
+ return frame_state
+
+ def __call__(self, score: np.ndarray, waveform: np.ndarray,
+ is_final: bool = False, max_end_sil: int = 800, online: bool = False
+ ):
+ self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
+ self.waveform = waveform # compute decibel for each frame
+ self.ComputeDecibel()
+ self.ComputeScores(score)
+ if not is_final:
+ self.DetectCommonFrames()
+ else:
+ self.DetectLastFrames()
+ segments = []
+ for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
+ segment_batch = []
+ if len(self.output_data_buf) > 0:
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+ if online:
+ if not self.output_data_buf[i].contain_seg_start_point:
+ continue
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+ if self.output_data_buf[i].contain_seg_end_point:
+ end_ms = self.output_data_buf[i].end_ms
+ self.next_seg = True
+ self.output_data_buf_offset += 1
+ else:
+ end_ms = -1
+ self.next_seg = False
+ else:
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
+ continue
+ start_ms = self.output_data_buf[i].start_ms
+ end_ms = self.output_data_buf[i].end_ms
+ self.output_data_buf_offset += 1
+ segment = [start_ms, end_ms]
+ segment_batch.append(segment)
+
+ if segment_batch:
+ segments.append(segment_batch)
+ if is_final:
+ # reset class variables and clear the dict for the next query
+ self.AllResetDetection()
+ return segments
+
+ def DetectCommonFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+ return 0
+
+ def DetectLastFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ if i != 0:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+ else:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+ return 0
+
+ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+ tmp_cur_frm_state = FrameState.kFrameStateInvalid
+ if cur_frm_state == FrameState.kFrameStateSpeech:
+ if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+ tmp_cur_frm_state = FrameState.kFrameStateSpeech
+ else:
+ tmp_cur_frm_state = FrameState.kFrameStateSil
+ elif cur_frm_state == FrameState.kFrameStateSil:
+ tmp_cur_frm_state = FrameState.kFrameStateSil
+ state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+ frm_shift_in_ms = self.vad_opts.frame_in_ms
+ if AudioChangeState.kChangeStateSil2Speech == state_change:
+ silence_frame_count = self.continous_silence_frame_count
+ self.continous_silence_frame_count = 0
+ self.pre_end_silence_detected = False
+ start_frame = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+ self.OnVoiceStart(start_frame)
+ self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+ for t in range(start_frame + 1, cur_frm_idx + 1):
+ self.OnVoiceDetected(t)
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
+ self.OnVoiceDetected(t)
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+ self.continous_silence_frame_count = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ pass
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+ self.continous_silence_frame_count = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.max_time_out = True
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSil2Sil == state_change:
+ self.continous_silence_frame_count += 1
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ # silence timeout, return zero length decision
+ if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+ self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+ or (is_final_frame and self.number_end_time_detected == 0):
+ for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
+ self.OnSilenceDetected(t)
+ self.OnVoiceStart(0, True)
+ self.OnVoiceEnd(0, True, False);
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ else:
+ if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
+ self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
+ lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+ if self.vad_opts.do_extend:
+ lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+ lookback_frame -= 1
+ lookback_frame = max(0, lookback_frame)
+ self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif self.vad_opts.do_extend and not is_final_frame:
+ if self.continous_silence_frame_count <= int(
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+ self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+ self.ResetDetection()
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
index 11a8644..c92db4e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+import copy
import numpy as np
from typeguard import check_argument_types
@@ -153,6 +154,187 @@
cmvn = np.array([means, vars])
return cmvn
+
+class WavFrontendOnline(WavFrontend):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
+ # add variables
+ self.frame_sample_length = int(self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000)
+ self.frame_shift_sample_length = int(self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000)
+ self.waveform = None
+ self.reserve_waveforms = None
+ self.input_cache = None
+ self.lfr_splice_cache = []
+
+ @staticmethod
+ # inputs has catted the cache
+ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
+ np.ndarray, np.ndarray, int]:
+ """
+ Apply lfr with data
+ """
+
+ LFR_inputs = []
+ T = inputs.shape[0] # include the right context
+ T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
+ splice_idx = T_lfr
+ for i in range(T_lfr):
+ if lfr_m <= T - i * lfr_n:
+ LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
+ else: # process last LFR frame
+ if is_final:
+ num_padding = lfr_m - (T - i * lfr_n)
+ frame = (inputs[i * lfr_n:]).reshape(-1)
+ for _ in range(num_padding):
+ frame = np.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ else:
+ # update splice_idx and break the circle
+ splice_idx = i
+ break
+ splice_idx = min(T - 1, splice_idx * lfr_n)
+ lfr_splice_cache = inputs[splice_idx:, :]
+ LFR_outputs = np.vstack(LFR_inputs)
+ return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
+
+ @staticmethod
+ def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
+ frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
+ return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+
+
+ def fbank(
+ self,
+ input: np.ndarray,
+ input_lengths: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ self.fbank_fn = knf.OnlineFbank(self.opts)
+ batch_size = input.shape[0]
+ if self.input_cache is None:
+ self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
+ input = np.concatenate((self.input_cache, input), axis=1)
+ frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
+ # update self.in_cache
+ self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
+ waveforms = np.empty(0, dtype=np.int16)
+ feats_pad = np.empty(0, dtype=np.float32)
+ feats_lens = np.empty(0, dtype=np.int32)
+ if frame_num:
+ waveforms = []
+ feats = []
+ feats_lens = []
+ for i in range(batch_size):
+ waveform = input[i]
+ waveforms.append(
+ waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
+ waveform = waveform * (1 << 15)
+
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+ frames = self.fbank_fn.num_frames_ready
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
+ for i in range(frames):
+ mat[i, :] = self.fbank_fn.get_frame(i)
+ feat = mat.astype(np.float32)
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
+ feats.append(mat)
+ feats_lens.append(feat_len)
+
+ waveforms = np.stack(waveforms)
+ feats_lens = np.array(feats_lens)
+ feats_pad = np.array(feats)
+ self.fbanks = feats_pad
+ self.fbanks_lens = copy.deepcopy(feats_lens)
+ return waveforms, feats_pad, feats_lens
+
+ def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
+ return self.fbanks, self.fbanks_lens
+
+ def lfr_cmvn(
+ self,
+ input: np.ndarray,
+ input_lengths: np.ndarray,
+ is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
+ batch_size = input.shape[0]
+ feats = []
+ feats_lens = []
+ lfr_splice_frame_idxs = []
+ for i in range(batch_size):
+ mat = input[i, :input_lengths[i], :]
+ lfr_splice_frame_idx = -1
+ if self.lfr_m != 1 or self.lfr_n != 1:
+ # update self.lfr_splice_cache in self.apply_lfr
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
+ is_final)
+ if self.cmvn_file is not None:
+ mat = self.apply_cmvn(mat)
+ feat_length = mat.shape[0]
+ feats.append(mat)
+ feats_lens.append(feat_length)
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+
+ feats_lens = np.array(feats_lens)
+ feats_pad = np.array(feats)
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+
+ def extract_fbank(
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ batch_size = input.shape[0]
+ assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
+ waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
+ if feats.shape[0]:
+ self.waveforms = waveforms if self.reserve_waveforms is None else np.concatenate(
+ (self.reserve_waveforms, waveforms), axis=1)
+ if not self.lfr_splice_cache:
+ for i in range(batch_size):
+ self.lfr_splice_cache.append(np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0))
+
+ if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
+ lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
+ feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
+ feats_lengths += lfr_splice_cache_np[0].shape[0]
+ frame_from_waveforms = int(
+ (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+ minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+ feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(feats, feats_lengths, is_final)
+ if self.lfr_m == 1:
+ self.reserve_waveforms = None
+ else:
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+ # print('frame_frame: ' + str(frame_from_waveforms))
+ self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
+ sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
+ self.waveforms = self.waveforms[:, :sample_length]
+ else:
+ # update self.reserve_waveforms and self.lfr_splice_cache
+ self.reserve_waveforms = self.waveforms[:,
+ :-(self.frame_sample_length - self.frame_shift_sample_length)]
+ for i in range(batch_size):
+ self.lfr_splice_cache[i] = np.concatenate((self.lfr_splice_cache[i], feats[i]), axis=0)
+ return np.empty(0, dtype=np.float32), feats_lengths
+ else:
+ if is_final:
+ self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
+ feats = np.stack(self.lfr_splice_cache)
+ feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
+ feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
+ if is_final:
+ self.cache_reset()
+ return feats, feats_lengths
+
+ def get_waveforms(self):
+ return self.waveforms
+
+ def cache_reset(self):
+ self.fbank_fn = knf.OnlineFbank(self.opts)
+ self.reserve_waveforms = None
+ self.input_cache = None
+ self.lfr_splice_cache = []
+
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@@ -188,4 +370,4 @@
return feat, feat_len
if __name__ == '__main__':
- test()
\ No newline at end of file
+ test()
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 2edde11..78c3f0d 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -24,21 +24,11 @@
):
check_argument_types()
- # self.token_list = self.load_token(token_path)
self.token_list = token_list
self.unk_symbol = token_list[-1]
+ self.token2id = {v: i for i, v in enumerate(self.token_list)}
+ self.unk_id = self.token2id[self.unk_symbol]
- # @staticmethod
- # def load_token(file_path: Union[Path, str]) -> List:
- # if not Path(file_path).exists():
- # raise TokenIDConverterError(f'The {file_path} does not exist.')
- #
- # with open(str(file_path), 'rb') as f:
- # token_list = pickle.load(f)
- #
- # if len(token_list) != len(set(token_list)):
- # raise TokenIDConverterError('The Token exists duplicated symbol.')
- # return token_list
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
@@ -51,13 +41,8 @@
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
- token2id = {v: i for i, v in enumerate(self.token_list)}
- if self.unk_symbol not in token2id:
- raise TokenIDConverterError(
- f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
- )
- unk_id = token2id[self.unk_symbol]
- return [token2id.get(i, unk_id) for i in tokens]
+
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer():
@@ -188,7 +173,7 @@
input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
- return self.session.run(None, input_dict)
+ return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
@@ -215,6 +200,38 @@
if not model_path.is_file():
raise FileExistsError(f'{model_path} is not a file.')
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
+
+def code_mix_split_words(text: str):
+ words = []
+ segs = text.split()
+ for seg in segs:
+ # There is no space in seg.
+ current_word = ""
+ for c in seg:
+ if len(c.encode()) == 1:
+ # This is an ASCII char.
+ current_word += c
+ else:
+ # This is a Chinese char.
+ if len(current_word) > 0:
+ words.append(current_word)
+ current_word = ""
+ words.append(c)
+ if len(current_word) > 0:
+ words.append(current_word)
+ return words
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
@@ -226,7 +243,7 @@
@functools.lru_cache()
-def get_logger(name='rapdi_paraformer'):
+def get_logger(name='funasr_onnx'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
new file mode 100644
index 0000000..ab8f041
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -0,0 +1,280 @@
+# -*- encoding: utf-8 -*-
+
+import os.path
+from pathlib import Path
+from typing import List, Union, Tuple
+
+import copy
+import librosa
+import numpy as np
+
+from .utils.utils import (ONNXRuntimeError,
+ OrtInferSession, get_logger,
+ read_yaml)
+from .utils.frontend import WavFrontend, WavFrontendOnline
+from .utils.e2e_vad import E2EVadModel
+
+logging = get_logger()
+
+
+class Fsmn_vad():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ max_end_sil: int = None,
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
+
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ config_file = os.path.join(model_dir, 'vad.yaml')
+ cmvn_file = os.path.join(model_dir, 'vad.mvn')
+ config = read_yaml(config_file)
+
+ self.frontend = WavFrontend(
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
+ )
+ self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.batch_size = batch_size
+ self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+ self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+ self.encoder_conf = config["encoder_conf"]
+
+ def prepare_cache(self, in_cache: list = []):
+ if len(in_cache) > 0:
+ return in_cache
+ fsmn_layers = self.encoder_conf["fsmn_layers"]
+ proj_dim = self.encoder_conf["proj_dim"]
+ lorder = self.encoder_conf["lorder"]
+ for i in range(fsmn_layers):
+ cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32)
+ in_cache.append(cache)
+ return in_cache
+
+
+ def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
+ waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ is_final = kwargs.get('kwargs', False)
+
+ segments = [[]] * self.batch_size
+ for beg_idx in range(0, waveform_nums, self.batch_size):
+
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ waveform = waveform_list[beg_idx:end_idx]
+ feats, feats_len = self.extract_feat(waveform)
+ waveform = np.array(waveform)
+ param_dict = kwargs.get('param_dict', dict())
+ in_cache = param_dict.get('in_cache', list())
+ in_cache = self.prepare_cache(in_cache)
+ try:
+ t_offset = 0
+ step = int(min(feats_len.max(), 6000))
+ for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
+ if t_offset + step >= feats_len - 1:
+ step = feats_len - t_offset
+ is_final = True
+ else:
+ is_final = False
+ feats_package = feats[:, t_offset:int(t_offset + step), :]
+ waveform_package = waveform[:, t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400)]
+
+ inputs = [feats_package]
+ # inputs = [feats]
+ inputs.extend(in_cache)
+ scores, out_caches = self.infer(inputs)
+ in_cache = out_caches
+ segments_part = self.vad_scorer(scores, waveform_package, is_final=is_final, max_end_sil=self.max_end_sil, online=False)
+ # segments = self.vad_scorer(scores, waveform[0][None, :], is_final=is_final, max_end_sil=self.max_end_sil)
+
+ if segments_part:
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
+
+ except ONNXRuntimeError:
+ # logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ segments = ''
+
+ return segments
+
+ def load_data(self,
+ wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(
+ f'The type of {wav_content} is not in [str, np.ndarray, list]')
+
+ def extract_feat(self,
+ waveform_list: List[np.ndarray]
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ feats, feats_len = [], []
+ for waveform in waveform_list:
+ speech, _ = self.frontend.fbank(waveform)
+ feat, feat_len = self.frontend.lfr_cmvn(speech)
+ feats.append(feat)
+ feats_len.append(feat_len)
+
+ feats = self.pad_feats(feats, np.max(feats_len))
+ feats_len = np.array(feats_len).astype(np.int32)
+ return feats, feats_len
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
+ return np.pad(feat, pad_width, 'constant', constant_values=0)
+
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
+ def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+
+ outputs = self.ort_infer(feats)
+ scores, out_caches = outputs[0], outputs[1:]
+ return scores, out_caches
+
+
+class Fsmn_vad_online():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ max_end_sil: int = None,
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
+
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ config_file = os.path.join(model_dir, 'vad.yaml')
+ cmvn_file = os.path.join(model_dir, 'vad.mvn')
+ config = read_yaml(config_file)
+
+ self.frontend = WavFrontendOnline(
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
+ )
+ self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.batch_size = batch_size
+ self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+ self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+ self.encoder_conf = config["encoder_conf"]
+
+ def prepare_cache(self, in_cache: list = []):
+ if len(in_cache) > 0:
+ return in_cache
+ fsmn_layers = self.encoder_conf["fsmn_layers"]
+ proj_dim = self.encoder_conf["proj_dim"]
+ lorder = self.encoder_conf["lorder"]
+ for i in range(fsmn_layers):
+ cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
+ in_cache.append(cache)
+ return in_cache
+
+ def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
+ waveforms = np.expand_dims(audio_in, axis=0)
+
+ param_dict = kwargs.get('param_dict', dict())
+ is_final = param_dict.get('is_final', False)
+ feats, feats_len = self.extract_feat(waveforms, is_final)
+ segments = []
+ if feats.size != 0:
+ in_cache = param_dict.get('in_cache', list())
+ in_cache = self.prepare_cache(in_cache)
+ try:
+ inputs = [feats]
+ inputs.extend(in_cache)
+ scores, out_caches = self.infer(inputs)
+ param_dict['in_cache'] = out_caches
+ waveforms = self.frontend.get_waveforms()
+ segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil,
+ online=True)
+
+
+ except ONNXRuntimeError:
+ # logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ segments = []
+ return segments
+
+ def load_data(self,
+ wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(
+ f'The type of {wav_content} is not in [str, np.ndarray, list]')
+
+ def extract_feat(self,
+ waveforms: np.ndarray, is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
+ for idx, waveform in enumerate(waveforms):
+ waveforms_lens[idx] = waveform.shape[-1]
+
+ feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
+ # feats.append(feat)
+ # feats_len.append(feat_len)
+
+ # feats = self.pad_feats(feats, np.max(feats_len))
+ # feats_len = np.array(feats_len).astype(np.int32)
+ return feats.astype(np.float32), feats_len.astype(np.int32)
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
+ return np.pad(feat, pad_width, 'constant', constant_values=0)
+
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
+ def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+
+ outputs = self.ort_infer(feats)
+ scores, out_caches = outputs[0], outputs[1:]
+ return scores, out_caches
+
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 3b9ed3b..975b14d 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,14 +13,14 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.2'
+VERSION_NUM = '0.0.5'
setuptools.setup(
name=MODULE_NAME,
version=VERSION_NUM,
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license='MIT',
diff --git a/funasr/runtime/python/utils/test_rtf_gpu.py b/funasr/runtime/python/utils/test_rtf_gpu.py
new file mode 100644
index 0000000..84cd2c7
--- /dev/null
+++ b/funasr/runtime/python/utils/test_rtf_gpu.py
@@ -0,0 +1,58 @@
+
+import time
+import sys
+import librosa
+from funasr.utils.types import str2bool
+
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_dir', type=str, required=True)
+parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]')
+parser.add_argument('--wav_file', type=str, default=None, help='amp fallback number')
+parser.add_argument('--quantize', type=str2bool, default=False, help='quantized model')
+parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx')
+parser.add_argument('--batch_size', type=int, default=1, help='batch_size for onnx')
+args = parser.parse_args()
+
+
+from funasr.runtime.python.libtorch.funasr_torch import Paraformer
+if args.backend == "onnx":
+ from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+
+model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
+
+wav_file_f = open(args.wav_file, 'r')
+wav_files = wav_file_f.readlines()
+
+# warm-up
+total = 0.0
+num = 30
+wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
+for i in range(num):
+ beg_time = time.time()
+ result = model(wav_path)
+ end_time = time.time()
+ duration = end_time-beg_time
+ total += duration
+ print(result)
+ print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
+
+# infer time
+wav_path = []
+beg_time = time.time()
+for i, wav_path_i in enumerate(wav_files):
+ wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ wav_path += [wav_path_i]
+result = model(wav_path)
+end_time = time.time()
+duration = (end_time-beg_time)*1000
+print("total_time_comput_ms: {}".format(int(duration)))
+
+duration_time = 0.0
+for i, wav_path_i in enumerate(wav_files):
+ wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+ waveform, _ = librosa.load(wav_path, sr=16000)
+ duration_time += len(waveform)/16.0
+print("total_time_wav_ms: {}".format(int(duration_time)))
+
+print("total_rtf: {:.5}".format(duration/duration_time))
\ No newline at end of file
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 14987f1..777513e 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -464,6 +464,12 @@
default=sys.maxsize,
help="The maximum number update step to train",
)
+ parser.add_argument(
+ "--batch_interval",
+ type=int,
+ default=10000,
+ help="The batch interval for saving model.",
+ )
group.add_argument(
"--patience",
type=int_or_none,
@@ -1576,13 +1582,21 @@
) -> AbsIterFactory:
assert check_argument_types()
+ if hasattr(args, "frontend_conf"):
+ if args.frontend_conf is not None and "fs" in args.frontend_conf:
+ dest_sample_rate = args.frontend_conf["fs"]
+ else:
+ dest_sample_rate = 16000
+ else:
+ dest_sample_rate = 16000
+
dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
float_dtype=args.train_dtype,
preprocess=iter_options.preprocess_fn,
max_cache_size=iter_options.max_cache_size,
max_cache_fd=iter_options.max_cache_fd,
- dest_sample_rate=args.frontend_conf["fs"],
+ dest_sample_rate=dest_sample_rate,
)
cls.check_task_requirements(
dataset, args.allow_variable_data_keys, train=iter_options.train
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 6e0f16a..e151473 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -412,12 +412,6 @@
default="13_15",
help="The range of noise decibel level.",
)
- parser.add_argument(
- "--batch_interval",
- type=int,
- default=10000,
- help="The batch interval for saving model.",
- )
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 096a5c8..45e4ce7 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -1,3 +1,11 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+https://arxiv.org/abs/2211.10243
+TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+https://arxiv.org/abs/2303.05397
+"""
+
import argparse
import logging
import os
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 608c1d3..80d66d5 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -15,7 +15,7 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.abs_model import LanguageModel
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetLanguageModel),
+ default=get_default_kwargs(LanguageModel),
help="The keyword arguments for model class.",
)
@@ -178,7 +178,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+ def build_model(cls, args: argparse.Namespace) -> LanguageModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@
# 2. Build ESPnetModel
# Assume the last-id is sos_and_eos
- model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+ model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
# 3. Initialize
if args.init is not None:
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index ea1e102..0170f28 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetPunctuationModel),
+ default=get_default_kwargs(PunctuationModel),
help="The keyword arguments for model class.",
)
@@ -183,7 +183,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+ def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@
# Assume the last-id is sos_and_eos
if "punc_weight" in args.model_conf:
args.model_conf.pop("punc_weight")
- model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+ model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
# FIXME(kamo): Should be done in model?
# 3. Initialize
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index bef5dc5..9710447 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -1,3 +1,7 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+"""
+
import argparse
import logging
import os
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index 22a5cb3..d07acf1 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -40,7 +40,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
@@ -81,6 +81,7 @@
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
+ wav_frontend_online=WavFrontendOnline,
),
type_check=AbsFrontend,
default="default",
@@ -291,7 +292,24 @@
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
- model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@@ -302,6 +320,7 @@
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
device: str = "cpu",
+ cmvn_file: Union[Path, str] = None,
):
"""Build model from the files.
@@ -325,6 +344,8 @@
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
+ #if cmvn_file is not None:
+ args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
model.to(device)
diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 86%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b38..1c7ff3d 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,7 @@
+from abc import ABC
+from abc import abstractmethod
+
+
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -7,13 +11,34 @@
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-class ESPnetPunctuationModel(AbsESPnetModel):
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+ """The abstract class
+
+ To share the loss calculation way among different models,
+ We uses delegate pattern here:
+ The instance of this class should be passed to "LanguageModel"
+
+ This "model" is one of mediator objects for "Task" class.
+
+ """
+
+ @abstractmethod
+ def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def with_vad(self) -> bool:
+ raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+
def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
@@ -21,12 +46,12 @@
self.punc_weight = torch.Tensor(punc_weight)
self.sos = 1
self.eos = 2
-
+
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
- #if self.punc_model.with_vad():
+ # if self.punc_model.with_vad():
# print("This is a vad puncuation model.")
-
+
def nll(
self,
text: torch.Tensor,
@@ -54,7 +79,7 @@
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
-
+
if self.punc_model.with_vad():
# Should be VadRealtimeTransformer
assert vad_indexes is not None
@@ -62,7 +87,7 @@
else:
# Should be TargetDelayTransformer,
y, _ = self.punc_model(text, text_lengths)
-
+
# Calc negative log likelihood
# nll: (BxL,)
if self.training == False:
@@ -75,7 +100,8 @@
return nll, text_lengths
else:
self.punc_weight = self.punc_weight.to(punc.device)
- nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+ ignore_index=self.ignore_id)
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +113,7 @@
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, text_lengths
-
+
def batchify_nll(self,
text: torch.Tensor,
punc: torch.Tensor,
@@ -113,7 +139,7 @@
nlls = []
x_lengths = []
max_length = text_lengths.max()
-
+
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +158,7 @@
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
-
+
def forward(
self,
text: torch.Tensor,
@@ -146,15 +172,15 @@
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
-
+
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
-
+
def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
return {}
-
+
def inference(self,
text: torch.Tensor,
text_lengths: torch.Tensor,
diff --git a/funasr/utils/compute_wer.py b/funasr/utils/compute_wer.py
index 349a3f6..26a9f49 100755
--- a/funasr/utils/compute_wer.py
+++ b/funasr/utils/compute_wer.py
@@ -45,8 +45,8 @@
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
- cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
- cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
diff --git a/funasr/version.txt b/funasr/version.txt
index d15723f..267577d 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.3.2
+0.4.1
diff --git a/setup.py b/setup.py
index c3eed88..6deaf9c 100644
--- a/setup.py
+++ b/setup.py
@@ -123,7 +123,7 @@
name="funasr",
version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index b3c5a24..2f2f11d 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -43,6 +43,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
def test_paraformer(self):
inference_pipeline = pipeline(
@@ -51,6 +52,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
class TestMfccaInferencePipelines(unittest.TestCase):
diff --git a/tests/test_punctuation_pipeline.py b/tests/test_punctuation_pipeline.py
index 52be9bb..e582042 100644
--- a/tests/test_punctuation_pipeline.py
+++ b/tests/test_punctuation_pipeline.py
@@ -26,16 +26,14 @@
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
- model_revision="v1.0.0",
)
inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
vads = inputs.split("|")
- cache_out = []
rec_result_all = "outputs:"
+ param_dict = {"cache": []}
for vad in vads:
- rec_result = inference_pipeline(text_in=vad, cache=cache_out)
- cache_out = rec_result['cache']
- rec_result_all += rec_result['text']
+ rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
+ rec_result_all += rec_result["text"]
logger.info("punctuation inference result: {0}".format(rec_result_all))
--
Gitblit v1.9.1