From 9817785c66a13caa681a8e9e272f2ae949233542 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期二, 18 四月 2023 19:28:39 +0800
Subject: [PATCH] Merge pull request #380 from alibaba-damo-academy/main
---
funasr/runtime/onnxruntime/readme.md | 121
funasr/bin/asr_train_transducer.py | 46
funasr/runtime/grpc/Readme.md | 198
funasr/bin/lm_inference_launch.py | 3
egs/aishell/rnnt/local/aishell_data_prep.sh | 66
funasr/runtime/onnxruntime/include/libfunasrapi.h | 4
docs/installation.md | 48
egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml | 8
docs/runtime/export.md | 1
funasr/models/encoder/opennmt_encoders/conv_encoder.py | 2
funasr/runtime/onnxruntime/src/precomp.h | 1
tests/test_asr_inference_pipeline.py | 2
docs/huggingface_models.md | 94
funasr/models/encoder/resnet34_encoder.py | 12
funasr/bin/sv_inference_launch.py | 3
funasr/models/encoder/conformer_encoder.py | 634 ++++
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py | 35
funasr/models/encoder/sanm_encoder.py | 12
funasr/tasks/diar.py | 8
funasr/bin/punc_inference_launch.py | 3
funasr/tasks/abs_task.py | 4
egs/aishell/rnnt/utils | 1
docs/recipe/sv_recipe.md | 2
funasr/export/models/CT_Transformer.py | 12
funasr/runtime/python/onnxruntime/demo_vad_online.py | 28
funasr/models/e2e_vad.py | 29
funasr/modules/beam_search/beam_search_transducer.py | 704 ++++
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 12
funasr/models/decoder/rnnt_decoder.py | 258 +
egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py | 6
funasr/models/e2e_diar_sond.py | 8
funasr/models/e2e_tp.py | 2
funasr/export/models/e2e_asr_paraformer.py | 4
funasr/models/joint_net/joint_network.py | 61
funasr/bin/asr_inference_rnnt.py | 1185 +++----
docs/index.rst | 55
funasr/models/decoder/sanm_decoder.py | 4
egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml | 5
docs/recipe/asr_recipe.md | 2
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 80
funasr/export/models/__init__.py | 4
funasr/models/decoder/transformer_decoder.py | 2
.github/workflows/main.yml | 1
funasr/bin/asr_inference_paraformer_streaming.py | 18
funasr/datasets/dataset.py | 2
funasr/runtime/onnxruntime/CMakeLists.txt | 19
funasr/models/e2e_asr_transducer.py | 1013 ++++++
docs/recipe/vad_recipe.md | 2
docs/runtime/grpc_cpp.md | 1
funasr/bin/vad_inference_launch.py | 3
funasr/models/e2e_asr_paraformer.py | 298 +
funasr/export/README.md | 3
funasr/bin/diar_inference_launch.py | 3
docs/runtime/onnxruntime_python.md | 1
funasr/runtime/onnxruntime/include/Audio.h | 17
funasr/runtime/python/onnxruntime/README.md | 74
docs/runtime/grpc_python.md | 1
docs/modescope_pipeline/sv_pipeline.md | 20
docs/modescope_pipeline/vad_pipeline.md | 20
funasr/runtime/onnxruntime/src/CMakeLists.txt | 1
docs/papers.md | 32
funasr/models/e2e_sv.py | 4
funasr/runtime/python/grpc/Readme.md | 4
funasr/train/trainer.py | 18
setup.py | 2
docs/modescope_pipeline/lm_pipeline.md | 14
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 26
funasr/runtime/onnxruntime/src/resample.h | 137
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 12
funasr/runtime/python/libtorch/README.md | 83
docs/runtime/websocket_python.md | 1
funasr/runtime/onnxruntime/src/Audio.cpp | 262 +
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 6
funasr/bin/punctuation_infer_vadrealtime.py | 2
funasr/bin/tp_inference_launch.py | 3
docs/modescope_pipeline/tp_pipeline.md | 20
funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py | 6
funasr/runtime/grpc/paraformer_server.cc | 2
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 201 +
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py | 2
docs/modescope_pipeline/asr_pipeline.md | 20
README.md | 38
funasr/models/vad_realtime_transformer.py | 6
funasr/runtime/onnxruntime/src/Vocab.cpp | 15
docs/modelscope_models.md | 114
funasr/datasets/large_datasets/utils/tokenize.py | 8
funasr/modules/embedding.py | 100
funasr/version.txt | 2
funasr/modules/nets_utils.py | 195 +
.gitignore | 5
funasr/models/e2e_asr_mfcca.py | 6
egs/aishell/rnnt/path.sh | 5
docs/runtime/libtorch_python.md | 1
funasr/runtime/onnxruntime/src/resample.cc | 305 ++
egs/aishell/rnnt/run.sh | 247 +
docs/runtime/onnxruntime_cpp.md | 1
docs/recipe/punc_recipe.md | 2
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py | 6
funasr/models/target_delay_transformer.py | 6
funasr/modules/attention.py | 220 +
funasr/datasets/preprocessor.py | 7
funasr/modules/e2e_asr_common.py | 150 +
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 16
funasr/models/e2e_uni_asr.py | 2
docs/modescope_pipeline/modelscope_usages.md | 0
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py | 1
funasr/runtime/python/onnxruntime/demo_vad_offline.py | 11
docs/modescope_pipeline/quick_start.md | 139
funasr/runtime/python/onnxruntime/setup.py | 4
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py | 184 +
funasr/tasks/sv.py | 4
funasr/models/decoder/contextual_decoder.py | 2
funasr/tasks/asr.py | 394 ++
funasr/runtime/python/libtorch/setup.py | 6
funasr/modules/subsampling.py | 202 +
funasr/models/frontend/wav_frontend.py | 12
funasr/modules/streaming_utils/chunk_utilis.py | 2
/dev/null | 30
docs/recipe/lm_recipe.md | 2
funasr/models/predictor/cif.py | 12
tests/test_punctuation_pipeline.py | 8
docs/modescope_pipeline/punc_pipeline.md | 20
egs/aishell/rnnt/README.md | 18
funasr/modules/repeat.py | 91
funasr/bin/asr_inference_launch.py | 46
docs/images/logo.png | 0
126 files changed, 7,432 insertions(+), 1,336 deletions(-)
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 8cb22cb..2497ac2 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -6,6 +6,7 @@
push:
branches:
- dev_wjm
+ - main
- dev_lyh
jobs:
diff --git a/.gitignore b/.gitignore
index 603f712..13d2fff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,7 @@
test_local/
RapidASR
export/*
-*.pyc
\ No newline at end of file
+*.pyc
+.eggs
+MaaS-lib
+.gitignore
\ No newline at end of file
diff --git a/README.md b/README.md
index 64231ca..03156f3 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,22 @@
[//]: # (<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>)
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
+<p align="left">
+ <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-brightgreen.svg"></a>
+ <a href=""><img src="https://img.shields.io/badge/Python->=3.7,<=3.10-aff.svg"></a>
+ <a href=""><img src="https://img.shields.io/badge/Pytorch-%3E%3D1.11-blue"></a>
+</p>
<strong>FunASR</strong> hopes to build a bridge between academic research and industrial applications on speech recognition. By supporting the training & finetuning of the industrial-grade speech recognition model released on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition), researchers and developers can conduct research and production of speech recognition models more conveniently, and promote the development of speech recognition ecology. ASR for Fun！
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
| [**Highlights**](#highlights)
| [**Installation**](#installation)
-| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
-| [**Model Zoo**](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
| [**Contact**](#contact)
[**M2MET2.0 Guidence_CN**](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html)
@@ -25,7 +29,7 @@
For the release notes, please ref to [news](https://github.com/alibaba-damo-academy/FunASR/releases)
## Highlights
-- Many types of typical models are supported, e.g., [Tranformer](https://arxiv.org/abs/1706.03762), [Conformer](https://arxiv.org/abs/2005.08100), [Paraformer](https://arxiv.org/abs/2206.08317).
+- FunASR supports speech recognition(ASR), Multi-talker ASR, Voice Activity Detection(VAD), Punctuation Restoration, Language Models, Speaker Verification and Speaker diarization.
- We have released large number of academic and industrial pretrained models on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition)
- The pretrained model [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) obtains the best performance on many tasks in [SpeechIO leaderboard](https://github.com/SpeechColab/Leaderboard)
- FunASR supplies an easy-to-use pipeline to finetune pretrained models from [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition)
@@ -33,15 +37,37 @@
## Installation
+Install from pip
+```shell
+pip install -U funasr
+# For the users in China, you could install with the command:
+# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
+Or install from source code
+
+
``` sh
-pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+
```
+If you want to use the pretrained models in ModelScope, you should install the modelscope:
+
+```shell
+pip install -U modelscope
+# For the users in China, you could install with the command:
+# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
For more details, please ref to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
-## Usage
-For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
+[//]: # ()
+[//]: # (## Usage)
+
+[//]: # (For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)))
## Contact
diff --git a/docs/huggingface_models.md b/docs/huggingface_models.md
new file mode 100644
index 0000000..61754eb
--- /dev/null
+++ b/docs/huggingface_models.md
@@ -0,0 +1,94 @@
+# Pretrained Models on Huggingface
+
+## Model License
+- Apache License 2.0
+
+## Model Zoo
+Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
+
+### Speech Recognition Models
+#### Paraformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
+| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  | CN & EN  | Alibaba Speech Data (60000hours) | 8404       | 220M      | Offline        | Which could deal with arbitrary length input wav                                                                                  |
+| [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
+| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
+| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
+| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition |
+| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+
+
+#### UniASR Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
+| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
+| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
+
+#### Conformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
+| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
+
+
+#### RNN-T Models
+
+### Multi-talker Speech Recognition Models
+
+#### MFCCA Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:--------:|:------------------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary)     | CN       | AliMeeting、AISHELL-4、Simudata (917hours) | 4950       | 45M       | Offline        | Duration of input wav <= 20s, channel of input wav <= 8 channel                                                                   |
+
+
+
+### Voice Activity Detection Models
+
+| Model Name | Training Data | Parameters | Sampling Rate | Notes |
+|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
+
+### Punctuation Restoration Models
+
+| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
+
+### Language Models
+
+| Model Name | Training Data | Parameters | Vocab Size | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
+| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
+
+### Speaker Verification Models
+
+| Model Name | Training Data | Parameters | Number Speaker | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (1,200 hours) | 17.5M | 3465 | Xvector, speaker verification, Chinese |
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (60 hours) | 61M | 6135 | Xvector, speaker verification, English |
+
+### Speaker diarization Models
+
+| Model Name | Training Data | Parameters | Notes |
+|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (120 hours) | 40.5M | Speaker diarization, profiles and records, Chinese |
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (60 hours) | 12M | Speaker diarization, profiles and records, English |
+
+### Timestamp Prediction Models
+
+| Model Name | Language | Training Data | Parameters | Notes |
+|:--------------------------------------------------------------------------------------------------:|:--------------:|:-------------------:|:----------:|:------|
+| [TP-Aligner](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) | CN | Alibaba Speech Data (50000hours) | 37.8M | Timestamp prediction, Mandarin, middle size |
diff --git a/docs/images/logo.png b/docs/images/logo.png
new file mode 100644
index 0000000..7375de3
--- /dev/null
+++ b/docs/images/logo.png
Binary files differ
diff --git a/docs/index.rst b/docs/index.rst
index d7fc96b..e5b9ab8 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,19 +11,64 @@
.. toctree::
:maxdepth: 1
- :caption: Tutorial:
+ :caption: Installation
./installation.md
- ./papers.md
- ./get_started.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Recipe
+
+ ./recipe/asr_recipe.md
+ ./recipe/sv_recipe.md
+ ./recipe/punc_recipe.md
+ ./recipe/vad_recipe.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Custom Your Model
+
./build_task.md
.. toctree::
:maxdepth: 1
- :caption: ModelScope:
+ :caption: Model Zoo
./modelscope_models.md
- ./modelscope_usages.md
+ ./huggingface_models.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: ModelScope pipeline
+
+ ./modescope_pipeline/quick_start.md
+ ./modescope_pipeline/asr_pipeline.md
+ ./modescope_pipeline/vad_pipeline.md
+ ./modescope_pipeline/punc_pipeline.md
+ ./modescope_pipeline/tp_pipeline.md
+ ./modescope_pipeline/sv_pipeline.md
+ ./modescope_pipeline/lm_pipeline.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Runtime
+
+ ./runtime/export.md
+ ./runtime/onnxruntime_python.md
+ ./runtime/onnxruntime_cpp.md
+ ./runtime/libtorch_python.md
+ ./runtime/grpc_python.md
+ ./runtime/grpc_cpp.md
+ ./runtime/websocket_python.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Papers
+
+ ./papers.md
+
+
+
Indices and tables
==================
diff --git a/docs/installation.md b/docs/installation.md
index fb26913..04d8a84 100755
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,7 +1,13 @@
-# Installation
-FunASR is easy to install. The detailed installation steps are as follows:
+<p align="left">
+ <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-brightgreen.svg"></a>
+ <a href=""><img src="https://img.shields.io/badge/Python->=3.7,<=3.10-aff.svg"></a>
+ <a href=""><img src="https://img.shields.io/badge/Pytorch-%3E%3D1.11-blue"></a>
+</p>
-- Install Conda and create virtual environment:
+## Installation
+
+### Install Conda (Optional):
+
```sh
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
sh Miniconda3-latest-Linux-x86_64.sh
@@ -10,26 +16,38 @@
conda activate funasr
```
-- Install Pytorch (version >= 1.7.0):
+### Install Pytorch (version >= 1.11.0):
+
```sh
pip install torch torchaudio
```
-For more versions, please see [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally)
+For more details about torch, please see [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally)
-- Install ModelScope
+### Install funasr
-For users in China, you can configure the following mirror source to speed up the downloading:
-``` sh
-pip config set global.index-url https://mirror.sjtu.edu.cn/pypi/web/simple
-```
-Install or update ModelScope
-```sh
-pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+#### Install from pip
+
+```shell
+pip install -U funasr
+# For the users in China, you could install with the command:
+# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
-- Clone the repo and install other packages
+#### Or install from source code
+
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install --editable ./
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
+### Install modelscope (Optional)
+If you want to use the pretrained models in ModelScope, you should install the modelscope:
+
+```shell
+pip install -U modelscope
+# For the users in China, you could install with the command:
+# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
\ No newline at end of file
diff --git a/docs/modelscope_models.md b/docs/modelscope_models.md
index 277d8e9..b35d625 100644
--- a/docs/modelscope_models.md
+++ b/docs/modelscope_models.md
@@ -1,4 +1,4 @@
-# Pretrained models on ModelScope
+# Pretrained Models on ModelScope
## Model License
- Apache License 2.0
@@ -6,29 +6,89 @@
## Model Zoo
Here we provided several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
-| Datasets | Hours | Model | Online/Offline | Language | Framework | Checkpoint |
-|:-----:|:-----:|:--------------:|:--------------:| :---: | :---: | --- |
-| Alibaba Speech Data | 60000 | Paraformer | Offline | CN | Pytorch |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 50000 | Paraformer | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Offline | CN | Tensorflow |[speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Online | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 50000 | UniASR | Offline | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 20000 | UniASR | Online | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 20000 | UniASR | Offline | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
-| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
-| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
-| AISHELL-1 | 178 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) |
-| AISHELL-2 | 1000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) |
-| AISHELL-1 | 178 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
-| AISHELL-2 | 1000 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
-| AISHELL-1 | 178 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
-| AISHELL-2 | 1000 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
+### Speech Recognition Models
+#### Paraformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
+| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which could deal with arbitrary length input wav |
+| [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords. |
+| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
+| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
+| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition |
+| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
+| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
+
+
+#### UniASR Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
+| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
+| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
+| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
+
+#### Conformer Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
+| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
+
+
+#### RNN-T Models
+
+### Multi-talker Speech Recognition Models
+
+#### MFCCA Models
+
+| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:--------:|:------------------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
+| [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary) | CN | AliMeeting, AISHELL-4, Simudata (917hours) | 4950 | 45M | Offline | Duration of input wav <= 20s, channel of input wav <= 8 channel |
+
+
+
+### Voice Activity Detection Models
+
+| Model Name | Training Data | Parameters | Sampling Rate | Notes |
+|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
+| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
+
+### Punctuation Restoration Models
+
+| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
+|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
+
+### Language Models
+
+| Model Name | Training Data | Parameters | Vocab Size | Notes |
+|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
+| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
+
+### Speaker Verification Models
+
+| Model Name | Training Data | Parameters | Number of Speakers | Notes |
+|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (1,200 hours) | 17.5M | 3465 | Xvector, speaker verification, Chinese |
+| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (60 hours) | 61M | 6135 | Xvector, speaker verification, English |
+
+### Speaker diarization Models
+
+| Model Name | Training Data | Parameters | Notes |
+|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (120 hours) | 40.5M | Speaker diarization, profiles and records, Chinese |
+| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (60 hours) | 12M | Speaker diarization, profiles and records, English |
+
+### Timestamp Prediction Models
+
+| Model Name | Language | Training Data | Parameters | Notes |
+|:--------------------------------------------------------------------------------------------------:|:--------------:|:-------------------:|:----------:|:------|
+| [TP-Aligner](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) | CN | Alibaba Speech Data (50000hours) | 37.8M | Timestamp prediction, Mandarin, middle size |
diff --git a/docs/modescope_pipeline/asr_pipeline.md b/docs/modescope_pipeline/asr_pipeline.md
new file mode 100644
index 0000000..3dc0bd0
--- /dev/null
+++ b/docs/modescope_pipeline/asr_pipeline.md
@@ -0,0 +1,20 @@
+# Speech Recognition
+
+## Inference
+
+### Quick start
+
+#### Inference with your data
+
+#### Inference with multi-threads on CPU
+
+#### Inference with multi GPU
+
+## Finetune with pipeline
+
+### Quick start
+
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/modescope_pipeline/lm_pipeline.md b/docs/modescope_pipeline/lm_pipeline.md
new file mode 100644
index 0000000..cb81871
--- /dev/null
+++ b/docs/modescope_pipeline/lm_pipeline.md
@@ -0,0 +1,14 @@
+# Speech Recognition
+
+## Inference with pipeline
+### Quick start
+#### Inference with your data
+#### Inference with multi-threads on CPU
+#### Inference with multi GPU
+
+## Finetune with pipeline
+### Quick start
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/modelscope_usages.md b/docs/modescope_pipeline/modelscope_usages.md
similarity index 100%
rename from docs/modelscope_usages.md
rename to docs/modescope_pipeline/modelscope_usages.md
diff --git a/docs/modescope_pipeline/punc_pipeline.md b/docs/modescope_pipeline/punc_pipeline.md
new file mode 100644
index 0000000..67ee695
--- /dev/null
+++ b/docs/modescope_pipeline/punc_pipeline.md
@@ -0,0 +1,20 @@
+# Punctuation Restoration
+
+## Inference with pipeline
+
+### Quick start
+
+#### Inference with your data
+
+#### Inference with multi-threads on CPU
+
+#### Inference with multi GPU
+
+## Finetune with pipeline
+
+### Quick start
+
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/modescope_pipeline/quick_start.md b/docs/modescope_pipeline/quick_start.md
new file mode 100644
index 0000000..ab46a7c
--- /dev/null
+++ b/docs/modescope_pipeline/quick_start.md
@@ -0,0 +1,139 @@
+# Quick Start
+
+## Inference with pipeline
+
+### Speech Recognition
+#### Paraformer model
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+)
+
+rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+print(rec_result)
+```
+
+### Voice Activity Detection
+#### FSMN-VAD
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+import logging
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+inference_pipeline = pipeline(
+ task=Tasks.voice_activity_detection,
+ model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
+ )
+
+segments_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
+print(segments_result)
+```
+
+### Punctuation Restoration
+#### CT_Transformer
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.punctuation,
+ model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
+ )
+
+rec_result = inference_pipeline(text_in='我们都是木头人不会讲话不会动')
+print(rec_result)
+```
+
+### Timestamp Prediction
+#### TP-Aligner
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.speech_timestamp,
+ model='damo/speech_timestamp_prediction-v1-16k-offline',
+ output_dir='./tmp')
+
+rec_result = inference_pipeline(
+ audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
+ text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢',)
+print(rec_result)
+```
+
+### Speaker Verification
+#### X-vector
+```python
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import numpy as np
+
+inference_sv_pipline = pipeline(
+ task=Tasks.speaker_verification,
+ model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
+)
+
+# embedding extract
+spk_embedding = inference_sv_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')["spk_embedding"]
+
+# speaker verification
+rec_result = inference_sv_pipline(audio_in=('https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav','https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
+print(rec_result["scores"][0])
+```
+
+## Finetune with pipeline
+### Speech Recognition
+#### Paraformer model
+
+finetune.py
+```python
+import os
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+
+def modelscope_finetune(params):
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
+ # dataset split ["train", "validation"]
+ ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
+ kwargs = dict(
+ model=params.model,
+ data_dir=ds_dict,
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
+ trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ from funasr.utils.modelscope_param import modelscope_args
+ params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ params.output_dir = "./checkpoint" # path to save the finetuned model
+ params.data_path = "speech_asr_aishell1_trainsets" # dataset path: a dataset already uploaded to ModelScope, or local data
+ params.dataset_type = "small" # use "small" for small datasets; use "large" if the data exceeds 1000 hours
+ params.batch_bins = 2000 # batch size: if dataset_type="small", the unit of batch_bins is fbank feature frames; if dataset_type="large", the unit is milliseconds
+ params.max_epoch = 50 # maximum number of training epochs
+ params.lr = 0.00005 # learning rate
+
+ modelscope_finetune(params)
+```
+
+```shell
+python finetune.py &> log.txt &
+```
+If you want to finetune with multiple GPUs, you could:
+```shell
+CUDA_VISIBLE_DEVICES=1,2 python -m torch.distributed.launch --nproc_per_node 2 finetune.py > log.txt 2>&1
+```
+
diff --git a/docs/modescope_pipeline/sv_pipeline.md b/docs/modescope_pipeline/sv_pipeline.md
new file mode 100644
index 0000000..6ce8c6a
--- /dev/null
+++ b/docs/modescope_pipeline/sv_pipeline.md
@@ -0,0 +1,20 @@
+# Speaker Verification
+
+## Inference with pipeline
+
+### Quick start
+
+#### Inference with your data
+
+#### Inference with multi-threads on CPU
+
+#### Inference with multi GPU
+
+## Finetune with pipeline
+
+### Quick start
+
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/modescope_pipeline/tp_pipeline.md b/docs/modescope_pipeline/tp_pipeline.md
new file mode 100644
index 0000000..fad55e3
--- /dev/null
+++ b/docs/modescope_pipeline/tp_pipeline.md
@@ -0,0 +1,20 @@
+# Timestamp Prediction
+
+## Inference with pipeline
+
+### Quick start
+
+#### Inference with your data
+
+#### Inference with multi-threads on CPU
+
+#### Inference with multi GPU
+
+## Finetune with pipeline
+
+### Quick start
+
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/modescope_pipeline/vad_pipeline.md b/docs/modescope_pipeline/vad_pipeline.md
new file mode 100644
index 0000000..5dcbe59
--- /dev/null
+++ b/docs/modescope_pipeline/vad_pipeline.md
@@ -0,0 +1,20 @@
+# Voice Activity Detection
+
+## Inference with pipeline
+
+### Quick start
+
+#### Inference with your data
+
+#### Inference with multi-threads on CPU
+
+#### Inference with multi GPU
+
+## Finetune with pipeline
+
+### Quick start
+
+### Finetune with your data
+
+## Inference with your finetuned model
+
diff --git a/docs/papers.md b/docs/papers.md
index e9a83e4..33bf72f 100644
--- a/docs/papers.md
+++ b/docs/papers.md
@@ -1,4 +1,34 @@
# Papers
+FunASR has implemented the code of the following papers
+
+### Speech Recognition
+- [Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition](https://arxiv.org/abs/2206.08317), INTERSPEECH 2022.
- [Universal ASR: Unifying Streaming and Non-Streaming ASR Using a Single Encoder-Decoder Model](https://arxiv.org/abs/2010.14099), arXiv preprint arXiv:2010.14099, 2020.
-- [Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition](https://arxiv.org/abs/2206.08317), INTERSPEECH 2022.
\ No newline at end of file
+- [San-m: Memory equipped self-attention for end-to-end speech recognition](https://arxiv.org/pdf/2006.01713), INTERSPEECH 2020
+- [Streaming Chunk-Aware Multihead Attention for Online End-to-End Speech Recognition](https://arxiv.org/abs/2006.01712), INTERSPEECH 2020
+- [Conformer: Convolution-augmented Transformer for Speech Recognition](https://arxiv.org/abs/2005.08100), INTERSPEECH 2020
+- [Sequence Transduction with Recurrent Neural Networks](https://arxiv.org/pdf/1211.3711.pdf), arXiv 2012
+
+
+### Multi-talker Speech Recognition
+- [MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario](https://arxiv.org/abs/2210.05265), ICASSP 2022
+
+### Voice Activity Detection
+- [Deep-FSMN for Large Vocabulary Continuous Speech Recognition](https://arxiv.org/abs/1803.05030), ICASSP 2018
+
+### Punctuation Restoration
+- [CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection](https://arxiv.org/pdf/2003.01309.pdf), ICASSP 2020
+
+### Language Models
+- [Attention Is All You Need](https://arxiv.org/abs/1706.03762), NEURIPS 2017
+
+### Speaker Verification
+- [X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION](https://www.danielpovey.com/files/2018_icassp_xvectors.pdf), ICASSP 2018
+
+### Speaker diarization
+- [Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis](https://arxiv.org/abs/2211.10243), EMNLP 2022
+- [TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization](https://arxiv.org/abs/2303.05397), ICASSP 2023
+
+### Timestamp Prediction
+- [Achieving Timestamp Prediction While Recognizing with Non-Autoregressive End-to-End ASR Model](https://arxiv.org/abs/2301.12343), arXiv:2301.12343
diff --git a/docs/get_started.md b/docs/recipe/asr_recipe.md
similarity index 98%
copy from docs/get_started.md
copy to docs/recipe/asr_recipe.md
index 4a7d86e..f82a6fe 100644
--- a/docs/get_started.md
+++ b/docs/recipe/asr_recipe.md
@@ -1,4 +1,4 @@
-# Get Started
+# Speech Recognition
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
diff --git a/docs/get_started.md b/docs/recipe/lm_recipe.md
similarity index 98%
copy from docs/get_started.md
copy to docs/recipe/lm_recipe.md
index 4a7d86e..f82a6fe 100644
--- a/docs/get_started.md
+++ b/docs/recipe/lm_recipe.md
@@ -1,4 +1,4 @@
-# Get Started
+# Speech Recognition
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
diff --git a/docs/get_started.md b/docs/recipe/punc_recipe.md
similarity index 98%
copy from docs/get_started.md
copy to docs/recipe/punc_recipe.md
index 4a7d86e..0306cd3 100644
--- a/docs/get_started.md
+++ b/docs/recipe/punc_recipe.md
@@ -1,4 +1,4 @@
-# Get Started
+# Punctuation Restoration
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
diff --git a/docs/get_started.md b/docs/recipe/sv_recipe.md
similarity index 98%
rename from docs/get_started.md
rename to docs/recipe/sv_recipe.md
index 4a7d86e..0eebe3d 100644
--- a/docs/get_started.md
+++ b/docs/recipe/sv_recipe.md
@@ -1,4 +1,4 @@
-# Get Started
+# Speaker Verification
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
diff --git a/docs/get_started.md b/docs/recipe/vad_recipe.md
similarity index 98%
copy from docs/get_started.md
copy to docs/recipe/vad_recipe.md
index 4a7d86e..6aa7532 100644
--- a/docs/get_started.md
+++ b/docs/recipe/vad_recipe.md
@@ -1,4 +1,4 @@
-# Get Started
+# Voice Activity Detection
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
diff --git a/docs/runtime/export.md b/docs/runtime/export.md
new file mode 120000
index 0000000..91f8b98
--- /dev/null
+++ b/docs/runtime/export.md
@@ -0,0 +1 @@
+../../funasr/export/README.md
\ No newline at end of file
diff --git a/docs/runtime/grpc_cpp.md b/docs/runtime/grpc_cpp.md
new file mode 120000
index 0000000..590a5f7
--- /dev/null
+++ b/docs/runtime/grpc_cpp.md
@@ -0,0 +1 @@
+../../funasr/runtime/grpc/Readme.md
\ No newline at end of file
diff --git a/docs/runtime/grpc_python.md b/docs/runtime/grpc_python.md
new file mode 120000
index 0000000..ee8d6ea
--- /dev/null
+++ b/docs/runtime/grpc_python.md
@@ -0,0 +1 @@
+../../funasr/runtime/python/grpc/Readme.md
\ No newline at end of file
diff --git a/docs/runtime/libtorch_python.md b/docs/runtime/libtorch_python.md
new file mode 120000
index 0000000..e8d6288
--- /dev/null
+++ b/docs/runtime/libtorch_python.md
@@ -0,0 +1 @@
+../../funasr/runtime/python/libtorch/README.md
\ No newline at end of file
diff --git a/docs/runtime/onnxruntime_cpp.md b/docs/runtime/onnxruntime_cpp.md
new file mode 120000
index 0000000..3661d18
--- /dev/null
+++ b/docs/runtime/onnxruntime_cpp.md
@@ -0,0 +1 @@
+../../funasr/runtime/onnxruntime/readme.md
\ No newline at end of file
diff --git a/docs/runtime/onnxruntime_python.md b/docs/runtime/onnxruntime_python.md
new file mode 120000
index 0000000..693bd5d
--- /dev/null
+++ b/docs/runtime/onnxruntime_python.md
@@ -0,0 +1 @@
+../../funasr/runtime/python/onnxruntime/README.md
\ No newline at end of file
diff --git a/docs/runtime/websocket_python.md b/docs/runtime/websocket_python.md
new file mode 120000
index 0000000..0fabb85
--- /dev/null
+++ b/docs/runtime/websocket_python.md
@@ -0,0 +1 @@
+../../funasr/runtime/python/websocket/README.md
\ No newline at end of file
diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md
new file mode 100644
index 0000000..45f1f3f
--- /dev/null
+++ b/egs/aishell/rnnt/README.md
@@ -0,0 +1,18 @@
+
+# Streaming RNN-T Result
+
+## Training Config
+- 8 gpu(Tesla V100)
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train config: conf/train_conformer_rnnt_unified.yaml
+- chunk config: chunk size 16, full left chunk
+- LM config: LM was not used
+- Model size: 90M
+
+## Results (CER)
+- Decode config: conf/decode_rnnt_conformer_streaming.yaml
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 5.53 |
+| test | 6.24 |
diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml
new file mode 100644
index 0000000..26e43c6
--- /dev/null
+++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml
@@ -0,0 +1,8 @@
+# The conformer transducer decoding configuration from @jeon30c
+beam_size: 10
+simu_streaming: false
+streaming: true
+chunk_size: 16
+left_context: 16
+right_context: 0
+
diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml
new file mode 100644
index 0000000..dc3eff2
--- /dev/null
+++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml
@@ -0,0 +1,5 @@
+# The conformer transducer decoding configuration from @jeon30c
+beam_size: 10
+simu_streaming: true
+streaming: false
+chunk_size: 16
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
new file mode 100644
index 0000000..8a1c40c
--- /dev/null
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -0,0 +1,80 @@
+encoder: chunk_conformer
+encoder_conf:
+ activation_type: swish
+ positional_dropout_rate: 0.5
+ time_reduction_factor: 2
+ unified_model_training: true
+ default_chunk_size: 16
+ jitter_range: 4
+ left_chunk_size: 0
+ embed_vgg_like: false
+ subsampling_factor: 4
+ linear_units: 2048
+ output_size: 512
+ attention_heads: 8
+ dropout_rate: 0.5
+  # (duplicate key removed: positional_dropout_rate is already set above)
+ attention_dropout_rate: 0.5
+ cnn_module_kernel: 15
+ num_blocks: 12
+
+# decoder related
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
+ embed_size: 512
+ hidden_size: 512
+ embed_dropout_rate: 0.5
+ dropout_rate: 0.5
+
+joint_network_conf:
+ joint_space_size: 512
+
+# Auxiliary CTC
+model_conf:
+ auxiliary_ctc_weight: 0.0
+
+# minibatch related
+use_amp: true
+batch_type: unsorted
+batch_size: 16
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 200
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - cer_transducer_chunk
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+normalize: None
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 40
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 50
+ num_time_mask: 5
+
+log_interval: 50
diff --git a/egs/aishell/rnnt/local/aishell_data_prep.sh b/egs/aishell/rnnt/local/aishell_data_prep.sh
new file mode 100755
index 0000000..83f489b
--- /dev/null
+++ b/egs/aishell/rnnt/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+ echo "Usage: $0 <audio-path> <text-path> <output-path>"
+ echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+ exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+  echo "Error: $0 requires a valid audio directory and transcript file"
+ exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+  echo Warning: expected 141925 data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+ echo Preparing $dir transcriptions
+ sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+ paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+ utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+ awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+ utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+ sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+ cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+ cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+ cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/egs/aishell/rnnt/path.sh b/egs/aishell/rnnt/path.sh
new file mode 100644
index 0000000..7972642
--- /dev/null
+++ b/egs/aishell/rnnt/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell/rnnt/run.sh b/egs/aishell/rnnt/run.sh
new file mode 100755
index 0000000..bcd4a8b
--- /dev/null
+++ b/egs/aishell/rnnt/run.sh
@@ -0,0 +1,247 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3"
+gpu_num=4
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir= #feature output directory
+exp_dir=
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+scp=feats.scp
+type=kaldi_ark
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=32
+speed_perturb="0.9,1.0,1.1"
+
+# data
+data_aishell=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_conformer_rnnt_unified.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_rnnt_conformer_streaming.yaml
+inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ for x in train dev test; do
+ cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+ done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature Generation"
+ # compute fbank features
+ fbankdir=${feats_dir}/fbank
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
+ ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+ utils/fix_data_feat.sh ${fbankdir}/train
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
+ utils/fix_data_feat.sh ${fbankdir}/dev
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
+ utils/fix_data_feat.sh ${fbankdir}/test
+
+ # compute global cmvn
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
+ ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+ # apply cmvn
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
+
+ cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
+ cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
+ cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
+
+ utils/fix_data_feat.sh ${feat_train_dir}
+ utils/fix_data_feat.sh ${feat_dev_dir}
+ utils/fix_data_feat.sh ${feat_test_dir}
+
+ #generate ark list
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
+ cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
+ cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
+fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ asr_train_transducer.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --token_type char \
+ --token_list $token_list \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --input_size $feats_dim \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+      echo "${_dir} already exists. If you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${dumpdir}/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode rnnt \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
diff --git a/egs/aishell/rnnt/utils b/egs/aishell/rnnt/utils
new file mode 120000
index 0000000..4072eac
--- /dev/null
+++ b/egs/aishell/rnnt/utils
@@ -0,0 +1 @@
+../transformer/utils
\ No newline at end of file
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
index 5f4563d..3db6f7d 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
@@ -1,3 +1,9 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+https://arxiv.org/abs/2303.05397
+"""
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
index db22c18..db10193 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
@@ -1,3 +1,9 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+https://arxiv.org/abs/2211.10243
+"""
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index da1241a..2b6716e 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
@@ -131,6 +134,11 @@
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
+ group.add_argument(
+ "--beam_search_config",
+ default={},
+ help="The keyword arguments for transducer beam search.",
+ )
group = parser.add_argument_group("Beam-search related")
group.add_argument(
@@ -168,6 +176,41 @@
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
+ group.add_argument("--simu_streaming", type=str2bool, default=False)
+ group.add_argument("--chunk_size", type=int, default=16)
+ group.add_argument("--left_context", type=int, default=16)
+ group.add_argument("--right_context", type=int, default=0)
+ group.add_argument(
+ "--display_partial_hypotheses",
+ type=bool,
+ default=False,
+ help="Whether to display partial hypotheses during chunk-by-chunk inference.",
+ )
+
+ group = parser.add_argument_group("Dynamic quantization related")
+ group.add_argument(
+ "--quantize_asr_model",
+ type=bool,
+ default=False,
+ help="Apply dynamic quantization to ASR model.",
+ )
+ group.add_argument(
+ "--quantize_modules",
+ nargs="*",
+ default=None,
+ help="""Module names to apply dynamic quantization on.
+ The module names are provided as a list, where each name is separated
+ by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+ Each specified name should be an attribute of 'torch.nn', e.g.:
+ torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+ )
+ group.add_argument(
+ "--quantize_dtype",
+ type=str,
+ default="qint8",
+ choices=["float16", "qint8"],
+ help="Dtype for dynamic quantization.",
+ )
group = parser.add_argument_group("Text converter related")
group.add_argument(
@@ -265,6 +308,9 @@
elif mode == "mfcca":
from funasr.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
+ elif mode == "rnnt":
+ from funasr.bin.asr_inference_rnnt import inference
+ return inference(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 66dec39..944685f 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
+import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -607,17 +608,22 @@
):
# 3. Build data-iterator
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
- raw_inputs = _load_bytes(data_path_and_name_and_type[0])
- raw_inputs = torch.tensor(raw_inputs)
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, np.ndarray):
- raw_inputs = torch.tensor(raw_inputs)
is_final = False
+ cache = {}
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
if param_dict is not None and "is_final" in param_dict:
is_final = param_dict["is_final"]
+
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+ raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+ raw_inputs = torch.tensor(raw_inputs)
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ is_final = True
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, np.ndarray):
+ raw_inputs = torch.tensor(raw_inputs)
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 2189a71..bff8702 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -1,396 +1,149 @@
#!/usr/bin/env python3
+
+""" Inference class definition for Transducer models."""
+
+from __future__ import annotations
+
import argparse
import logging
+import math
import sys
-import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
-from typeguard import check_argument_types
+from packaging.version import parse as V
+from typeguard import check_argument_types, check_return_type
+from funasr.modules.beam_search.beam_search_transducer import (
+ BeamSearchTransducer,
+ Hypothesis,
+)
+from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.asr import ASRTransducerTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-
class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
+ """Speech2Text class for Transducer models.
+ Args:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model config path.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ device: Device to use for inference.
+ beam_size: Size of beam during search.
+ dtype: Data type.
+ lm_weight: Language model weight.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ nbest: Number of final hypothesis.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
"""
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
- assert check_argument_types()
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
+ ) -> None:
+ """Construct a Speech2Text object."""
+ super().__init__()
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ assert check_argument_types()
+ asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
+
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
+ if quantize_asr_model:
+ if quantize_modules is not None:
+ if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+ raise ValueError(
+ "Only 'Linear' and 'LSTM' modules are currently supported"
+ " by PyTorch and in --quantize_modules"
+ )
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
+ q_config = set([getattr(torch.nn, q) for q in quantize_modules])
+ else:
+ q_config = {torch.nn.Linear}
- # 2. Build Language model
+ if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+ raise ValueError(
+ "float16 dtype for dynamic quantization is not supported with torch"
+ " version < 1.5.0. Switching to qint8 dtype instead."
+ )
+ q_dtype = getattr(torch, quantize_dtype)
+
+ asr_model = torch.quantization.quantize_dynamic(
+ asr_model, q_config, dtype=q_dtype
+ ).eval()
+ else:
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
+ lm_scorer = lm.lm
+ else:
+ lm_scorer = None
# 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
+ if beam_search_config is None:
+ beam_search_config = {}
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
+ beam_search = BeamSearchTransducer(
+ asr_model.decoder,
+ asr_model.joint_network,
+ beam_size,
+ lm=lm_scorer,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ **beam_search_config,
)
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
- # 6. [Optional] Build hotword list from str, local file or url
- self.hotword_list = None
- self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, enc_len = self.asr_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- if not isinstance(self.asr_model, ContextualParaformer):
- if self.hotword_list:
- logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
-
- # assert check_return_type(results)
- return results
-
- def generate_hotwords_list(self, hotword_list_or_file):
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
-
-class Speech2TextExport:
- """Speech2TextExport class
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
token_list = asr_model.token_list
-
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
@@ -407,197 +160,277 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
- # self.asr_model = asr_model
+
+ self.asr_model = asr_model
self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
self.device = device
self.dtype = dtype
self.nbest = nbest
- self.frontend = frontend
- model = Paraformer_export(asr_model, onnx=False)
- self.asr_model = model
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.beam_search = beam_search
+ self.streaming = streaming
+ self.simu_streaming = simu_streaming
+ self.chunk_size = max(chunk_size, 0)
+ self.left_context = max(left_context, 0)
+ self.right_context = max(right_context, 0)
+
+ if not streaming or chunk_size == 0:
+ self.streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+ if not simu_streaming or chunk_size == 0:
+ self.simu_streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ self.frontend = frontend
+ self.window_size = self.chunk_size + self.right_context
+
+ self._ctx = self.asr_model.encoder.get_encoder_input_size(
+ self.window_size
+ )
+
+ #self.last_chunk_length = (
+ # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ #) * self.hop_length
+
+ self.last_chunk_length = (
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ )
+ self.reset_inference_cache()
+
+ def reset_inference_cache(self) -> None:
+ """Reset Speech2Text parameters."""
+ self.frontend_cache = None
+
+ self.asr_model.encoder.reset_streaming_cache(
+ self.left_context, device=self.device
+ )
+ self.beam_search.reset_inference_cache()
+
+ self.num_processed_frames = torch.tensor([[0]], device=self.device)
+
@torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
+ def streaming_decode(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
+ ) -> List[Hypothesis]:
+ """Speech2Text streaming call.
Args:
- speech: Input speech data
+ speech: Chunk of speech data. (S)
+ is_final: Whether speech corresponds to the final chunk of data.
Returns:
- text, token, token_int, hyp
+ nbest_hypothesis: N-best hypothesis.
+ """
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if is_final:
+ if self.streaming and speech.size(0) < self.last_chunk_length:
+ pad = torch.zeros(
+ self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
+ )
+ speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final)
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.chunk_forward(
+ feats,
+ feats_lengths,
+ self.num_processed_frames,
+ chunk_size=self.chunk_size,
+ left_context=self.left_context,
+ right_context=self.right_context,
+ )
+ nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+
+ self.num_processed_frames += self.chunk_size
+
+ if is_final:
+ self.reset_inference_cache()
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
- # Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- enc_len_batch_total = feats_len.sum()
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- decoder_outs = self.asr_model(**batch)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context)
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+
+ enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]:
+ """Build partial or final results from the hypotheses.
+ Args:
+ nbest_hyps: N-best hypothesis.
+ Returns:
+ results: Results containing different representation for the hypothesis.
+ """
results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- am_scores = decoder_out[i, :ys_pad_lens[i], :]
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- yseq.tolist(), device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for hyp in nbest_hyps:
+ token_int = list(filter(lambda x: x != 0, hyp.yseq))
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
+ token = self.converter.ids2tokens(token_int)
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+ assert check_return_type(results)
return results
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+ ) -> Speech2Text:
+ """Build Speech2Text instance from the pretrained model.
+ Args:
+ model_tag: Model tag of the pretrained models.
+ Return:
+ : Speech2Text instance.
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2Text(**kwargs)
+
def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
-
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
-
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
+) -> None:
+ """Transducer model inference.
+ Args:
+ output_dir: Output directory path.
+ batch_size: Batch decoding size.
+ dtype: Data type.
+ beam_size: Beam size.
+ ngpu: Number of GPUs.
+ seed: Random number generator seed.
+ lm_weight: Weight of language model.
+ nbest: Number of final hypothesis.
+ num_workers: Number of workers.
+ log_level: Level of verbose for logs.
+ data_path_and_name_and_type:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model path.
+ model_tag: Model tag.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ key_file: File key.
+ allow_variable_data_keys: Whether to allow variable data keys.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
assert check_argument_types()
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
@@ -605,20 +438,11 @@
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
- export_mode = False
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- export_mode = param_dict.get("export_mode", False)
- else:
- hotword_list_or_file = None
- if ngpu >= 1 and torch.cuda.is_available():
+ if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
- batch_size = 1
-
# 1. Set random-seed
set_all_random_seed(seed)
@@ -627,143 +451,105 @@
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
+ beam_search_config=beam_search_config,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
- ctc_weight=ctc_weight,
lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
+ quantize_asr_model=quantize_asr_model,
+ quantize_modules=quantize_modules,
+ quantize_dtype=quantize_dtype,
+ streaming=streaming,
+ simu_streaming=simu_streaming,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
)
- if export_mode:
- speech2text = Speech2TextExport(**speech2text_kwargs)
- else:
- speech2text = Speech2Text(**speech2text_kwargs)
+ speech2text = Speech2Text.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
+ # 3. Build data-iterator
+ loader = ASRTransducerTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTransducerTask.build_preprocess_fn(
+ speech2text.asr_train_args, False
+ ),
+ collate_fn=ASRTransducerTask.build_collate_fn(
+ speech2text.asr_train_args, False
+ ),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- if 'hotword' in kwargs:
- hotword_list_or_file = kwargs['hotword']
- if hotword_list_or_file is not None or 'hotword' in kwargs:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
- cache = None
- if 'cache' in param_dict:
- cache = param_dict['cache']
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- forward_time_total = 0.0
- length_total = 0.0
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
+ # 4 .Start for-loop
+ with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
+
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
- logging.info("decoding, utt_id: {}".format(keys))
- # N-best list of (text, token, token_int, hyp_object)
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
- time_beg = time.time()
- results = speech2text(cache=cache, **batch)
- if len(results) < 1:
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
- time_end = time.time()
- forward_time = time_end - time_beg
- lfr_factor = results[0][-1]
- length = results[0][-2]
- forward_time_total += forward_time
- length_total += length
- rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
- logging.info(rtf_cur)
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
- for batch_id in range(_bs):
- result = [results[batch_id][:-2]]
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx : _end], is_final=False
+ )
- key = keys[batch_id]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
+ final_hyps = speech2text.streaming_decode(
+ speech[_end : len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["rtf"][key] = rtf_cur
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
- if text is not None:
- text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ ibest_writer = writer[f"{n}best_recog"]
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
- logging.info(rtf_avg)
- if writer is not None:
- ibest_writer["rtf"]["rtf_avf"] = rtf_avg
- return asr_result_list
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
- return _forward
+ if text is not None:
+ ibest_writer["text"][key] = text
def get_parser():
+ """Get Transducer model inference parser."""
+
parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
+ description="ASR Transducer Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
@@ -792,17 +578,12 @@
default=1,
help="The number of workers used for DataLoader",
)
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
- required=False,
+ required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
@@ -835,25 +616,10 @@
help="LM parameter file",
)
group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
+ "*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
@@ -864,42 +630,13 @@
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
+ group.add_argument("--beam_size", type=int, default=5, help="Beam size")
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
group.add_argument(
- "--frontend_conf",
- default=None,
- help="",
+ "--beam_search_config",
+ default={},
+ help="The keyword arguments for transducer beam search.",
)
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
@@ -908,14 +645,77 @@
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
- "If not given, refers from the training args",
+ "If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
- "If not given, refers from the training args",
+ "If not given, refers from the training args",
+ )
+
+ group = parser.add_argument_group("Dynamic quantization related")
+ parser.add_argument(
+ "--quantize_asr_model",
+ type=bool,
+ default=False,
+ help="Apply dynamic quantization to ASR model.",
+ )
+ parser.add_argument(
+ "--quantize_modules",
+ nargs="*",
+ default=None,
+ help="""Module names to apply dynamic quantization on.
+ The module names are provided as a list, where each name is separated
+ by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+ Each specified name should be an attribute of 'torch.nn', e.g.:
+ torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+ )
+ parser.add_argument(
+ "--quantize_dtype",
+ type=str,
+ default="qint8",
+ choices=["float16", "qint8"],
+ help="Dtype for dynamic quantization.",
+ )
+
+ group = parser.add_argument_group("Streaming related")
+ parser.add_argument(
+ "--streaming",
+ type=bool,
+ default=False,
+ help="Whether to perform chunk-by-chunk inference.",
+ )
+ parser.add_argument(
+ "--simu_streaming",
+ type=bool,
+ default=False,
+ help="Whether to simulate chunk-by-chunk inference.",
+ )
+ parser.add_argument(
+ "--chunk_size",
+ type=int,
+ default=16,
+ help="Number of frames in chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--left_context",
+ type=int,
+ default=32,
+ help="Number of frames in left context of the chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--right_context",
+ type=int,
+ default=0,
+ help="Number of frames in right context of the chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--display_partial_hypotheses",
+ type=bool,
+ default=False,
+ help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
return parser
@@ -923,24 +723,15 @@
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
+
parser = get_parser()
args = parser.parse_args(cmd)
- param_dict = {'hotword': args.hotword}
kwargs = vars(args)
+
kwargs.pop("config", None)
- kwargs['param_dict'] = param_dict
inference(**kwargs)
if __name__ == "__main__":
main()
- # from modelscope.pipelines import pipeline
- # from modelscope.utils.constant import Tasks
- #
- # inference_16k_pipline = pipeline(
- # task=Tasks.auto_speech_recognition,
- # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
- #
- # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
- # print(rec_result)
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
new file mode 100755
index 0000000..fe418db
--- /dev/null
+++ b/funasr/bin/asr_train_transducer.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.asr import ASRTransducerTask
+
+
+# for ASR Training
+def parse_args():
+ parser = ASRTransducerTask.get_parser()
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args=None, cmd=None):
+ # for ASR Training
+ ASRTransducerTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small":
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 85e4518..83436e8 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index 492ebab..d229cc6 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index e7e3f15..2c5a286 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/punc_train_vadrealtime.py b/funasr/bin/punc_train_vadrealtime.py
deleted file mode 100644
index c5afaad..0000000
--- a/funasr/bin/punc_train_vadrealtime.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#!/usr/bin/env python3
-import os
-from funasr.tasks.punctuation import PunctuationTask
-
-
-def parse_args():
- parser = PunctuationTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- parser.add_argument(
- "--punc_list",
- type=str,
- default=None,
- help="Punctuation list",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- """
- punc training.
- """
- PunctuationTask.main(args=args, cmd=cmd)
-
-
-if __name__ == "__main__":
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- main(args=args)
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index 81f9d7a..5157eeb 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -90,7 +90,7 @@
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+ "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 1205d19..64a3cff 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index dd76df6..55debac 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 42c5c1e..8fea8db 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import torch
+torch.set_num_threads(1)
+
import argparse
import logging
import os
diff --git a/funasr/datasets/dataset.py b/funasr/datasets/dataset.py
index 1595224..979479c 100644
--- a/funasr/datasets/dataset.py
+++ b/funasr/datasets/dataset.py
@@ -115,7 +115,7 @@
# NOTE(kamo): SoundScpReader doesn't support pipe-fashion
# like Kaldi e.g. "cat a.wav |".
# NOTE(kamo): The audio signal is normalized to [-1,1] range.
- loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
+ loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate = dest_sample_rate)
# SoundScpReader.__getitem__() returns Tuple[int, ndarray],
# but ndarray is desired, so Adapter class is inserted here
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index d8ceff2..022d321 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -19,6 +19,7 @@
def seg_tokenize(txt, seg_dict):
out_txt = ""
for word in txt:
+ word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
@@ -41,14 +42,13 @@
if seg_dict is not None:
assert isinstance(seg_dict, dict)
- txt = forward_segment("".join(text).lower(), seg_dict)
- text = seg_tokenize(txt, seg_dict)
+ text = seg_tokenize(text, seg_dict)
length = len(text)
for i in range(length):
x = text[i]
- if i == length-1 and "punc" in data and text[i].startswith("vad:"):
- vad = x[-1][4:]
+ if i == length-1 and "punc" in data and x.startswith("vad:"):
+ vad = x[4:]
if len(vad) == 0:
vad = -1
else:
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index afeff4e..20a3831 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -48,6 +48,7 @@
def seg_tokenize(txt, seg_dict):
out_txt = ""
for word in txt:
+ word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
@@ -359,7 +360,6 @@
if self.split_with_space:
tokens = text.strip().split(" ")
if self.seg_dict is not None:
- tokens = forward_segment("".join(tokens), self.seg_dict)
tokens = seg_tokenize(tokens, self.seg_dict)
else:
tokens = self.tokenizer.text2tokens(text)
@@ -786,6 +786,7 @@
) -> Dict[str, np.ndarray]:
for i in range(self.num_tokenizer):
text_name = self.text_name[i]
+ #import pdb; pdb.set_trace()
if text_name in data and self.tokenizer[i] is not None:
text = data[text_name]
text = self.text_cleaner(text)
@@ -800,7 +801,7 @@
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
-
+ return data
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
@@ -813,4 +814,4 @@
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
- return sentences
\ No newline at end of file
+ return sentences
diff --git a/funasr/export/README.md b/funasr/export/README.md
index 97a3de9..8f57673 100644
--- a/funasr/export/README.md
+++ b/funasr/export/README.md
@@ -1,3 +1,4 @@
+# Export models
## Environments
torch >= 1.11.0
@@ -7,7 +8,7 @@
## Install modelscope and funasr
-The installation is the same as [funasr](../../README.md)
+The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
## Export model
`Tips`: torch>=1.11.0
diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/CT_Transformer.py
similarity index 90%
rename from funasr/export/models/target_delay_transformer.py
rename to funasr/export/models/CT_Transformer.py
index 2780d82..932e3af 100644
--- a/funasr/export/models/target_delay_transformer.py
+++ b/funasr/export/models/CT_Transformer.py
@@ -9,7 +9,11 @@
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class CT_Transformer(nn.Module):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
model,
@@ -76,7 +80,11 @@
class CT_Transformer_VadRealtime(nn.Module):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
model,
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index f81ff64..0e3a782 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -4,10 +4,10 @@
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.export.models.target_delay_transformer import CT_Transformer as CT_Transformer_export
+from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
-from funasr.export.models.target_delay_transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
+from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 0db61e0..52ad320 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -19,7 +19,7 @@
class Paraformer(nn.Module):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -112,7 +112,7 @@
class BiCifParaformer(nn.Module):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 3b462e7..78105ab 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -102,7 +102,7 @@
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
new file mode 100644
index 0000000..5401ab2
--- /dev/null
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -0,0 +1,258 @@
+"""RNN decoder definition for Transducer models."""
+
+from typing import List, Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.specaug.specaug import SpecAug
+
+class RNNTDecoder(torch.nn.Module):
+ """RNN decoder module.
+
+ Args:
+ vocab_size: Vocabulary size.
+ embed_size: Embedding size.
+ hidden_size: Hidden size.
+ rnn_type: Decoder layers type.
+ num_layers: Number of decoder layers.
+ dropout_rate: Dropout rate for decoder layers.
+ embed_dropout_rate: Dropout rate for embedding layer.
+ embed_pad: Embedding padding symbol ID.
+
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ embed_size: int = 256,
+ hidden_size: int = 256,
+ rnn_type: str = "lstm",
+ num_layers: int = 1,
+ dropout_rate: float = 0.0,
+ embed_dropout_rate: float = 0.0,
+ embed_pad: int = 0,
+ ) -> None:
+ """Construct a RNNDecoder object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ if rnn_type not in ("lstm", "gru"):
+ raise ValueError(f"Not supported: rnn_type={rnn_type}")
+
+ self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
+ self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
+
+ rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
+
+ self.rnn = torch.nn.ModuleList(
+ [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
+ )
+
+ for _ in range(1, num_layers):
+ self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
+
+ self.dropout_rnn = torch.nn.ModuleList(
+ [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
+ )
+
+ self.dlayers = num_layers
+ self.dtype = rnn_type
+
+ self.output_size = hidden_size
+ self.vocab_size = vocab_size
+
+ self.device = next(self.parameters()).device
+ self.score_cache = {}
+
+ def forward(
+ self,
+ labels: torch.Tensor,
+ label_lens: torch.Tensor,
+ states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+ ) -> torch.Tensor:
+ """Encode source label sequences.
+
+ Args:
+ labels: Label ID sequences. (B, L)
+ states: Decoder hidden states.
+ ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+ Returns:
+ dec_out: Decoder output sequences. (B, U, D_dec)
+
+ """
+ if states is None:
+ states = self.init_state(labels.size(0))
+
+ dec_embed = self.dropout_embed(self.embed(labels))
+ dec_out, states = self.rnn_forward(dec_embed, states)
+ return dec_out
+
+ def rnn_forward(
+ self,
+ x: torch.Tensor,
+ state: Tuple[torch.Tensor, Optional[torch.Tensor]],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+ """Encode source label sequences.
+
+ Args:
+ x: RNN input sequences. (B, D_emb)
+ state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ Returns:
+ x: RNN output sequences. (B, D_dec)
+ (h_next, c_next): Decoder hidden states.
+ (N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ h_prev, c_prev = state
+ h_next, c_next = self.init_state(x.size(0))
+
+ for layer in range(self.dlayers):
+ if self.dtype == "lstm":
+ x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
+ layer
+ ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
+ else:
+ x, h_next[layer : layer + 1] = self.rnn[layer](
+ x, hx=h_prev[layer : layer + 1]
+ )
+
+ x = self.dropout_rnn[layer](x)
+
+ return x, (h_next, c_next)
+
+ def score(
+ self,
+ label: torch.Tensor,
+ label_sequence: List[int],
+ dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+ """One-step forward hypothesis.
+
+ Args:
+ label: Previous label. (1, 1)
+ label_sequence: Current label sequence.
+ dec_state: Previous decoder hidden states.
+ ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ Returns:
+ dec_out: Decoder output sequence. (1, D_dec)
+ dec_state: Decoder hidden states.
+ ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ """
+ str_labels = "_".join(map(str, label_sequence))
+
+ if str_labels in self.score_cache:
+ dec_out, dec_state = self.score_cache[str_labels]
+ else:
+ dec_embed = self.embed(label)
+ dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
+
+ self.score_cache[str_labels] = (dec_out, dec_state)
+
+ return dec_out[0], dec_state
+
+ def batch_score(
+ self,
+ hyps: List[Hypothesis],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+ """One-step forward hypotheses.
+
+ Args:
+ hyps: Hypotheses.
+
+ Returns:
+ dec_out: Decoder output sequences. (B, D_dec)
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ labels = torch.tensor([[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device)
+ dec_embed = self.embed(labels)
+
+ states = self.create_batch_states([h.dec_state for h in hyps])
+ dec_out, states = self.rnn_forward(dec_embed, states)
+
+ return dec_out.squeeze(1), states
+
+ def set_device(self, device: torch.device) -> None:
+ """Set GPU device to use.
+
+ Args:
+ device: Device ID.
+
+ """
+ self.device = device
+
+ def init_state(
+ self, batch_size: int
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Initialize decoder states.
+
+ Args:
+ batch_size: Batch size.
+
+ Returns:
+ : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ h_n = torch.zeros(
+ self.dlayers,
+ batch_size,
+ self.output_size,
+ device=self.device,
+ )
+
+ if self.dtype == "lstm":
+ c_n = torch.zeros(
+ self.dlayers,
+ batch_size,
+ self.output_size,
+ device=self.device,
+ )
+
+ return (h_n, c_n)
+
+ return (h_n, None)
+
+ def select_state(
+ self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Get specified ID state from decoder hidden states.
+
+ Args:
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+ idx: State ID to extract.
+
+ Returns:
+ : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ """
+ return (
+ states[0][:, idx : idx + 1, :],
+ states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
+ )
+
+ def create_batch_states(
+ self,
+ new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Create decoder hidden states.
+
+ Args:
+ new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
+
+ Returns:
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ return (
+ torch.cat([s[0] for s in new_states], dim=1),
+ torch.cat([s[1] for s in new_states], dim=1)
+ if self.dtype == "lstm"
+ else None,
+ )
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 463918a..18cd343 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -151,7 +151,7 @@
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -812,7 +812,7 @@
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index 5f1bb24..aed7f20 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -405,7 +405,7 @@
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index 0336133..f22f12a 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -36,7 +36,11 @@
import random
import math
class MFCCA(AbsESPnetModel):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """
+ Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
+ MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
+ https://arxiv.org/abs/2210.05265
+ """
def __init__(
self,
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index f1bb2bf..699d85f 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -44,7 +44,7 @@
class Paraformer(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@@ -325,67 +325,12 @@
return encoder_out, encoder_out_lens
- def encode_chunk(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
-
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
-
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
- # 4. Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
- feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
- )
- else:
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, torch.tensor([encoder_out.size(1)])
-
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
ignore_id=self.ignore_id)
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
- def calc_predictor_chunk(self, encoder_out, cache=None):
-
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -396,14 +341,6 @@
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
-
- def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
- decoder_outs = self.decoder.forward_chunk(
- encoder_out, sematic_embeds, cache["decoder"]
- )
- decoder_out = decoder_outs
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -610,9 +547,187 @@
return loss_ctc, cer_ctc
-class ParaformerBert(Paraformer):
+class ParaformerOnline(Paraformer):
"""
Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self, *args, **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ self.step_cur += 1
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ loss_pre = None
+ stats = dict()
+
+ # 1. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+ # Collect Intermediate CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+ # 2b. Attention decoder branch
+ if self.ctc_weight != 1.0:
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att + loss_pre * self.predictor_weight
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode_chunk(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+ feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ # Post-encoder, e.g. NLU
+ if self.postencoder is not None:
+ encoder_out, encoder_out_lens = self.postencoder(
+ encoder_out, encoder_out_lens
+ )
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+
+ return encoder_out, torch.tensor([encoder_out.size(1)])
+
+ def calc_predictor_chunk(self, encoder_out, cache=None):
+
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
+ self.predictor.forward_chunk(encoder_out, cache["encoder"])
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+ def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+ decoder_outs = self.decoder.forward_chunk(
+ encoder_out, sematic_embeds, cache["decoder"]
+ )
+ decoder_out = decoder_outs
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out
+
+
+class ParaformerBert(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
"""
@@ -977,6 +1092,59 @@
loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
return loss_pre2
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
def calc_predictor(self, encoder_out, encoder_out_lens):
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
new file mode 100644
index 0000000..0cae306
--- /dev/null
+++ b/funasr/models/e2e_asr_transducer.py
@@ -0,0 +1,1013 @@
+"""ESPnet2 ASR Transducer model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if V(torch.__version__) >= V("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class TransducerModel(AbsESPnetModel):
+    """ESPnet2ASRTransducerModel module definition.
+
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of token
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank Symbol
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: Encoder,
+        decoder: RNNTDecoder,
+        joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        extract_feats_in_collect_stats: bool = True,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        # Lazily instantiated on first use: the RNN-T criterion needs the
+        # optional warp-rnnt package and the error calculator is only needed
+        # at validation time (see _calc_transducer_loss).
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        # Trim label padding beyond the longest sequence (for data-parallel).
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+
+        # Auxiliary losses default to plain floats so the weighted sum below
+        # is a no-op when they are disabled.
+        loss_ctc, loss_lm = 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+
+        """
+        # Feature extraction runs in fp32 even under AMP for numerical stability.
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+
+        """
+        if self.criterion_transducer is None:
+            try:
+                # from warprnnt_pytorch import RNNTLoss
+                # self.criterion_transducer = RNNTLoss(
+                #     reduction="mean",
+                #     fastemit_lambda=self.fastemit_lambda,
+                # )
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                # NOTE(review): exit(1) aborts the whole process; raising the
+                # ImportError would be friendlier to library users.
+                exit(1)
+
+        # loss_transducer = self.criterion_transducer(
+        #     joint_out,
+        #     target,
+        #     t_len,
+        #     u_len,
+        # )
+        # warp_rnnt.rnnt_loss expects log-probabilities (unlike warprnnt_pytorch,
+        # which takes raw logits), so apply log_softmax here.
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+            log_probs,
+            target,
+            t_len,
+            u_len,
+            reduction="mean",
+            blank=self.blank_id,
+            fastemit_lambda=self.fastemit_lambda,
+            gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                from espnet2.asr_transducer.error_calculator import ErrorCalculator
+
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_ctc: CTC loss value.
+
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        # ctc_loss expects (T, B, V) log-probabilities.
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        # Drop zero entries (blank/padding ID 0) to get the flat 1-D target
+        # form accepted by torch.nn.functional.ctc_loss.
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+
+        Return:
+            loss_lm: LM loss value.
+
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        # Build the label-smoothed target distribution without tracking gradients.
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
+
+class UnifiedTransducerModel(AbsESPnetModel):
+    """ESPnet2 ASR Transducer model trained on both full-context and
+    chunk-wise (streaming) encoder outputs.
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of token
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank Symbol
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: Encoder,
+        decoder: RNNTDecoder,
+        joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_att_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_sos: str = "<sos/eos>",
+        sym_eos: str = "<sos/eos>",
+        extract_feats_in_collect_stats: bool = True,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+
+        # Fall back to vocab_size - 1 when the sos/eos symbol is absent.
+        if sym_sos in token_list:
+            self.sos = token_list.index(sym_sos)
+        else:
+            self.sos = vocab_size - 1
+        if sym_eos in token_list:
+            self.eos = token_list.index(sym_eos)
+        else:
+            self.eos = vocab_size - 1
+
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        # Lazily instantiated on first use (see _calc_transducer_loss).
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_att = auxiliary_att_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_att:
+            self.att_decoder = att_decoder
+
+            # NOTE(review): LabelSmoothingLoss is not imported in this file —
+            # add the proper import or this raises NameError at construction.
+            self.criterion_att = LabelSmoothingLoss(
+                size=vocab_size,
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=length_normalized_loss,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_att_weight = auxiliary_att_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        # Trim label padding beyond the longest sequence (for data-parallel).
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder (returns both full-context and chunk-masked outputs)
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # Auxiliary losses default to plain floats so the weighted sum below
+        # is a no-op when they are disabled.
+        loss_att, loss_att_chunk = 0.0, 0.0
+
+        if self.use_auxiliary_att:
+            loss_att, _ = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            loss_att_chunk, _ = self._calc_att_loss(
+                encoder_out_chunk, encoder_out_lens, text, text_lengths
+            )
+
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network (shared decoder output for both encoder branches)
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        joint_out_chunk = self.joint_network(
+            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
+            encoder_out_chunk,
+            joint_out_chunk,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+            loss_ctc_chunk = self._calc_ctc_loss(
+                encoder_out_chunk,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        # Sum the full-context and chunk branches of each loss.
+        loss_trans = loss_trans_utt + loss_trans_chunk
+        loss_ctc = loss_ctc + loss_ctc_chunk
+        # BUGFIX: this previously read `loss_ctc = loss_att + loss_att_chunk`,
+        # which clobbered the CTC loss with the attention losses and left
+        # loss_att without its chunk component.
+        loss_att = loss_att + loss_att_chunk
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_att_weight * loss_att
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans_utt.detach(),
+            loss_transducer_chunk=loss_trans_chunk.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
+            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
+            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+            cer_transducer_chunk=cer_trans_chunk,
+            wer_transducer_chunk=wer_trans_chunk,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_chunk: Chunk-wise encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+        """
+        # Feature extraction runs in fp32 even under AMP for numerical stability.
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_chunk, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+        """
+        if self.criterion_transducer is None:
+            try:
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                # NOTE(review): exit(1) aborts the whole process; raising the
+                # ImportError would be friendlier to library users.
+                exit(1)
+
+        # warp_rnnt.rnnt_loss expects log-probabilities (unlike warprnnt_pytorch,
+        # which takes raw logits), so apply log_softmax here.
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+            log_probs,
+            target,
+            t_len,
+            u_len,
+            reduction="mean",
+            blank=self.blank_id,
+            fastemit_lambda=self.fastemit_lambda,
+            gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                # NOTE(review): ErrorCalculator is not imported in this file
+                # (TransducerModel above imports it from
+                # espnet2.asr_transducer.error_calculator) — add the import or
+                # this raises NameError during validation.
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+        Return:
+            loss_ctc: CTC loss value.
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        # ctc_loss expects (T, B, V) log-probabilities.
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        # Drop zero entries (blank/padding ID 0) to get the flat 1-D target
+        # form accepted by torch.nn.functional.ctc_loss.
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+        Return:
+            loss_lm: LM loss value.
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        # Build the label-smoothed target distribution without tracking gradients.
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        """Compute auxiliary attention-decoder loss and accuracy.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            encoder_out_lens: Encoder output sequences lengths. (B,)
+            ys_pad: Padded label ID sequences. (B, L)
+            ys_pad_lens: Label ID sequences lengths. (B,)
+        Return:
+            loss_att: Attention loss value.
+            acc_att: Attention accuracy.
+        """
+        # Optional language-ID token prepended when configured on the instance.
+        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
+            ys_pad = torch.cat(
+                [
+                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
+                    ys_pad,
+                ],
+                dim=1,
+            )
+            ys_pad_lens += 1
+
+        # NOTE(review): add_sos_eos and th_accuracy are not imported in this
+        # file — add the proper imports or this raises NameError.
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.att_decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        return loss_att, acc_att
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index de669f2..3f7011d 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -36,8 +36,12 @@
class DiarSondModel(AbsESPnetModel):
- """Speaker overlap-aware neural diarization model
- reference: https://arxiv.org/abs/2211.10243
+ """
+ Author: Speech Lab, Alibaba Group, China
+ SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+ https://arxiv.org/abs/2211.10243
+ TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+ https://arxiv.org/abs/2303.05397
"""
def __init__(
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index eff63d9..5b21277 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -1,3 +1,7 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+"""
+
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 887439c..d1367ab 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -32,7 +32,7 @@
class TimestampPredictor(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ac4db32..ca76244 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -40,7 +40,7 @@
class UniASR(AbsESPnetModel):
"""
- Author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index ff37429..50ec475 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -35,6 +35,11 @@
class VADXOptions:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(
self,
sample_rate: int = 16000,
@@ -99,6 +104,11 @@
class E2EVadSpeechBufWithDoa(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self):
self.start_ms = 0
self.end_ms = 0
@@ -117,6 +127,11 @@
class E2EVadFrameProb(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@@ -126,6 +141,11 @@
class WindowDetector(object):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
@@ -192,6 +212,11 @@
class E2EVadModel(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
@@ -460,8 +485,8 @@
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
- i].contain_seg_end_point:
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..9777cee 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+from typing import Dict
import torch
from torch import nn
@@ -18,6 +19,7 @@
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.modules.embedding import (
@@ -25,16 +27,23 @@
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
+ StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +51,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@@ -275,6 +286,188 @@
return (x, pos_emb), mask
return x, mask
+
+class ChunkEncoderLayer(torch.nn.Module):
+ """Chunk Conformer module definition.
+ Args:
+ block_size: Input/output size.
+ self_att: Self-attention module instance.
+ feed_forward: Feed-forward module instance.
+ feed_forward_macaron: Feed-forward module instance for macaron network.
+ conv_mod: Convolution module instance.
+ norm_class: Normalization module class.
+ norm_args: Normalization module arguments.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ block_size: int,
+ self_att: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: torch.nn.Module,
+ conv_mod: torch.nn.Module,
+ norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_args: Dict = {},
+ dropout_rate: float = 0.0,
+ ) -> None:
+ """Construct a Conformer object."""
+ super().__init__()
+
+ self.self_att = self_att
+
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.feed_forward_scale = 0.5
+
+ self.conv_mod = conv_mod
+
+ self.norm_feed_forward = norm_class(block_size, **norm_args)
+ self.norm_self_att = norm_class(block_size, **norm_args)
+
+ self.norm_macaron = norm_class(block_size, **norm_args)
+ self.norm_conv = norm_class(block_size, **norm_args)
+ self.norm_final = norm_class(block_size, **norm_args)
+
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.block_size = block_size
+ self.cache = None
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset self-attention and convolution modules cache for streaming.
+ Args:
+ left_context: Number of left frames during chunk-by-chunk inference.
+ device: Device to use for cache tensor.
+ """
+ self.cache = [
+ torch.zeros(
+ (1, left_context, self.block_size),
+ device=device,
+ ),
+ torch.zeros(
+ (
+ 1,
+ self.block_size,
+ self.conv_mod.kernel_size - 1,
+ ),
+ device=device,
+ ),
+ ]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ mask: Source mask. (B, T)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.dropout(
+ self.feed_forward_macaron(x)
+ )
+
+ residual = x
+ x = self.norm_self_att(x)
+ x_q = x
+ x = residual + self.dropout(
+ self.self_att(
+ x_q,
+ x,
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ )
+
+ residual = x
+
+ x = self.norm_conv(x)
+ x, _ = self.conv_mod(x)
+ x = residual + self.dropout(x)
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+ x = self.norm_final(x)
+ return x, mask, pos_enc
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode chunk of input sequence.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+ residual = x
+ x = self.norm_self_att(x)
+ if left_context > 0:
+ key = torch.cat([self.cache[0], x], dim=1)
+ else:
+ key = x
+ val = key
+
+ if right_context > 0:
+ att_cache = key[:, -(left_context + right_context) : -right_context, :]
+ else:
+ att_cache = key[:, -left_context:, :]
+ x = residual + self.self_att(
+ x,
+ key,
+ val,
+ pos_enc,
+ mask,
+ left_context=left_context,
+ )
+
+ residual = x
+ x = self.norm_conv(x)
+ x, conv_cache = self.conv_mod(
+ x, cache=self.cache[1], right_context=right_context
+ )
+ x = residual + x
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+ x = self.norm_final(x)
+ self.cache = [att_cache, conv_cache]
+
+ return x, pos_enc
class ConformerEncoder(AbsEncoder):
@@ -604,3 +797,442 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+
+class CausalConvolution(torch.nn.Module):
+ """ConformerConvolution module definition.
+ Args:
+ channels: The number of channels.
+ kernel_size: Size of the convolving kernel.
+ activation: Type of activation function.
+ norm_args: Normalization module arguments.
+ causal: Whether to use causal convolution (set to True if streaming).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ norm_args: Dict = {},
+ causal: bool = False,
+ ) -> None:
+ """Construct an ConformerConvolution object."""
+ super().__init__()
+
+ assert (kernel_size - 1) % 2 == 0
+
+ self.kernel_size = kernel_size
+
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ if causal:
+ self.lorder = kernel_size - 1
+ padding = 0
+ else:
+ self.lorder = 0
+ padding = (kernel_size - 1) // 2
+
+ self.depthwise_conv = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ )
+ self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cache: Optional[torch.Tensor] = None,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x: ConformerConvolution input sequences. (B, T, D_hidden)
+ cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+ right_context: Number of frames in right context.
+ Returns:
+ x: ConformerConvolution output sequences. (B, T, D_hidden)
+ cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+ """
+ x = self.pointwise_conv1(x.transpose(1, 2))
+ x = torch.nn.functional.glu(x, dim=1)
+
+ if self.lorder > 0:
+ if cache is None:
+ x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ x = torch.cat([cache, x], dim=2)
+
+ if right_context > 0:
+ cache = x[:, :, -(self.lorder + right_context) : -right_context]
+ else:
+ cache = x[:, :, -self.lorder :]
+
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x).transpose(1, 2)
+
+ return x, cache
+
+class ConformerChunkEncoder(AbsEncoder):
+ """Encoder module definition.
+ Args:
+ input_size: Input size.
+ body_conf: Encoder body configuration.
+ input_conf: Encoder input configuration.
+ main_conf: Encoder main configuration.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ embed_vgg_like: bool = False,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ norm_type: str = "layer_norm",
+ cnn_module_kernel: int = 31,
+ conv_mod_norm_eps: float = 0.00001,
+ conv_mod_norm_momentum: float = 0.1,
+ simplified_att_score: bool = False,
+ dynamic_chunk_training: bool = False,
+ short_chunk_threshold: float = 0.75,
+ short_chunk_size: int = 25,
+ left_chunk_size: int = 0,
+ time_reduction_factor: int = 1,
+ unified_model_training: bool = False,
+ default_chunk_size: int = 16,
+ jitter_range: int = 4,
+ subsampling_factor: int = 1,
+ ) -> None:
+ """Construct an Encoder object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ self.embed = StreamingConvInput(
+ input_size,
+ output_size,
+ subsampling_factor,
+ vgg_like=embed_vgg_like,
+ output_size=output_size,
+ )
+
+ self.pos_enc = StreamingRelPositionalEncoding(
+ output_size,
+ positional_dropout_rate,
+ )
+
+ activation = get_activation(
+ activation_type
+ )
+
+ pos_wise_args = (
+ output_size,
+ linear_units,
+ positional_dropout_rate,
+ activation,
+ )
+
+ conv_mod_norm_args = {
+ "eps": conv_mod_norm_eps,
+ "momentum": conv_mod_norm_momentum,
+ }
+
+ conv_mod_args = (
+ output_size,
+ cnn_module_kernel,
+ activation,
+ conv_mod_norm_args,
+ dynamic_chunk_training or unified_model_training,
+ )
+
+ mult_att_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ simplified_att_score,
+ )
+
+
+ fn_modules = []
+ for _ in range(num_blocks):
+ module = lambda: ChunkEncoderLayer(
+ output_size,
+ RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ CausalConvolution(*conv_mod_args),
+ dropout_rate=dropout_rate,
+ )
+ fn_modules.append(module)
+
+ self.encoders = MultiBlocks(
+ [fn() for fn in fn_modules],
+ output_size,
+ )
+
+ self._output_size = output_size
+
+ self.dynamic_chunk_training = dynamic_chunk_training
+ self.short_chunk_threshold = short_chunk_threshold
+ self.short_chunk_size = short_chunk_size
+ self.left_chunk_size = left_chunk_size
+
+ self.unified_model_training = unified_model_training
+ self.default_chunk_size = default_chunk_size
+ self.jitter_range = jitter_range
+
+ self.time_reduction_factor = time_reduction_factor
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+ """Return the corresponding number of sample for a given chunk size, in frames.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ hop_length: Frontend's hop length
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size) * hop_length
+
+ def get_encoder_input_size(self, size: int) -> int:
+        """Return the corresponding number of feature frames before subsampling for a given chunk size.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size)
+
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset encoder streaming cache.
+ Args:
+ left_context: Number of frames in left context.
+ device: Device ID.
+ """
+ return self.encoders.reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+            x_len: Encoder outputs lengths. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ if self.unified_model_training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+ x_chunk = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x_utt, x_chunk, olens
+
+ elif self.dynamic_chunk_training:
+ max_len = x.size(1)
+ chunk_size = torch.randint(1, max_len, (1,)).item()
+
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ else:
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = None
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x, olens
+
+ def simu_chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ processed_frames: torch.tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Encode input sequences as chunks.
+ Args:
+ x: Encoder input features. (1, T_in, F)
+ x_len: Encoder input features lengths. (1,)
+ processed_frames: Number of frames already seen.
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ """
+ mask = make_source_mask(x_len)
+ x, mask = self.embed(x, mask, None)
+
+ if left_context > 0:
+ processed_mask = (
+ torch.arange(left_context, device=x.device)
+ .view(1, left_context)
+ .flip(1)
+ )
+ processed_mask = processed_mask >= processed_frames
+ mask = torch.cat([processed_mask, mask], dim=1)
+ pos_enc = self.pos_enc(x, left_context=left_context)
+ x = self.encoders.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ if right_context > 0:
+ x = x[:, 0:-right_context, :]
+
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ return x
diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
index a33e0b7..eec854f 100644
--- a/funasr/models/encoder/opennmt_encoders/conv_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -67,7 +67,7 @@
class ConvEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Convolution encoder in OpenNMT framework
"""
diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
index cf77bce..db30f08 100644
--- a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -117,7 +117,7 @@
class SelfAttentionEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
Self attention encoder in OpenNMT framework
"""
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 7d7179a..93695c8 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -406,6 +406,12 @@
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
+ """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+ SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+ https://arxiv.org/abs/2211.10243
+ """
+
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
@@ -633,6 +639,12 @@
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
+ """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+ TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+ https://arxiv.org/abs/2303.05397
+ """
+
super(ResNet34SpL2RegDiar, self).__init__(
input_size,
use_head_conv=use_head_conv,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a353..f2502bb 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -11,7 +11,7 @@
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
@@ -117,7 +117,7 @@
class SANMEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -180,6 +180,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -357,7 +359,7 @@
if self.embed is None:
xs_pad = xs_pad
else:
- xs_pad = self.embed.forward_chunk(xs_pad, cache)
+ xs_pad = self.embed(xs_pad, cache)
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
@@ -549,7 +551,7 @@
class SANMEncoderChunkOpt(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -962,7 +964,7 @@
class SANMVadEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 475a939..203f00e 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -38,7 +38,7 @@
return cmvn
-def apply_cmvn(inputs, cmvn_file): # noqa
+def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
@@ -47,7 +47,6 @@
dtype = inputs.dtype
frame, dim = inputs.shape
- cmvn = load_cmvn(cmvn_file)
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
@@ -111,6 +110,7 @@
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -140,8 +140,8 @@
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -194,8 +194,8 @@
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ if self.cmvn is not None:
+ mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
diff --git a/funasr/models/joint_net/joint_network.py b/funasr/models/joint_net/joint_network.py
new file mode 100644
index 0000000..ed827c4
--- /dev/null
+++ b/funasr/models/joint_net/joint_network.py
@@ -0,0 +1,61 @@
+"""Transducer joint network implementation."""
+
+import torch
+
+from funasr.modules.nets_utils import get_activation
+
+
+class JointNetwork(torch.nn.Module):
+ """Transducer joint network module.
+
+ Args:
+ output_size: Output size.
+ encoder_size: Encoder output size.
+        decoder_size: Decoder output size.
+ joint_space_size: Joint space size.
+ joint_act_type: Type of activation for joint network.
+ **activation_parameters: Parameters for the activation function.
+
+ """
+
+ def __init__(
+ self,
+ output_size: int,
+ encoder_size: int,
+ decoder_size: int,
+ joint_space_size: int = 256,
+ joint_activation_type: str = "tanh",
+ ) -> None:
+ """Construct a JointNetwork object."""
+ super().__init__()
+
+ self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
+ self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)
+
+ self.lin_out = torch.nn.Linear(joint_space_size, output_size)
+
+ self.joint_activation = get_activation(
+ joint_activation_type
+ )
+
+ def forward(
+ self,
+ enc_out: torch.Tensor,
+ dec_out: torch.Tensor,
+ project_input: bool = True,
+ ) -> torch.Tensor:
+ """Joint computation of encoder and decoder hidden state sequences.
+
+ Args:
+ enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
+ dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
+
+ Returns:
+ joint_out: Joint output state sequences. (B, T, U, D_out)
+
+ """
+ if project_input:
+ joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
+ else:
+ joint_out = self.joint_activation(enc_out + dec_out)
+ return self.lin_out(joint_out)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index e80a915..a5273f8 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -234,6 +234,7 @@
last_fire_place = len_time - 1
last_fire_remainds = 0.0
pre_alphas_length = 0
+ last_fire = False
mask_chunk_peak_predictor = None
if cache is not None:
@@ -251,10 +252,15 @@
if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
last_fire_place = len_time - 1 - i
last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
+ last_fire = True
break
- last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
- cache["cif_hidden"] = hidden[:, last_fire_place:, :]
- cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ if last_fire:
+ last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
+ cache["cif_hidden"] = hidden[:, last_fire_place:, :]
+ cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ else:
+ cache["cif_hidden"] = hidden
+ cache["cif_alphas"] = alphas
token_num_int = token_num.floor().type(torch.int32).item()
return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
index 84a2e6c..e893c65 100644
--- a/funasr/models/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -13,7 +13,11 @@
class TargetDelayTransformer(AbsPunctuation):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
vocab_size: int,
diff --git a/funasr/models/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
index 66f7fad..fe298ce 100644
--- a/funasr/models/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -11,7 +11,11 @@
class VadRealtimeTransformer(AbsPunctuation):
-
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(
self,
vocab_size: int,
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 31d5a87..6202079 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -11,7 +11,7 @@
import numpy
import torch
from torch import nn
-
+from typing import Optional, Tuple
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -741,3 +741,221 @@
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+ """RelPositionMultiHeadedAttention definition.
+ Args:
+ num_heads: Number of attention heads.
+ embed_size: Embedding size.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ num_heads: int,
+ embed_size: int,
+ dropout_rate: float = 0.0,
+ simplified_attention_score: bool = False,
+ ) -> None:
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+
+ self.d_k = embed_size // num_heads
+ self.num_heads = num_heads
+
+ assert self.d_k * num_heads == embed_size, (
+ "embed_size (%d) must be divisible by num_heads (%d)",
+ (embed_size, num_heads),
+ )
+
+ self.linear_q = torch.nn.Linear(embed_size, embed_size)
+ self.linear_k = torch.nn.Linear(embed_size, embed_size)
+ self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+ self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+ if simplified_attention_score:
+ self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+ self.compute_att_score = self.compute_simplified_attention_score
+ else:
+ self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+ self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ self.compute_att_score = self.compute_attention_score
+
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.attn = None
+
+ def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+ """Compute relative positional encoding.
+ Args:
+ x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+ left_context: Number of frames in left context.
+ Returns:
+ x: Output sequence. (B, H, T_1, T_2)
+ """
+ batch_size, n_heads, time1, n = x.shape
+ time2 = time1 + left_context
+
+ batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+ return x.as_strided(
+ (batch_size, n_heads, time1, time2),
+ (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+ storage_offset=(n_stride * (time1 - 1)),
+ )
+
+ def compute_simplified_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Simplified attention score computation.
+ Reference: https://github.com/k2-fsa/icefall/pull/458
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ pos_enc = self.linear_pos(pos_enc)
+
+ matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+ matrix_bd = self.rel_shift(
+ pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+ left_context=left_context,
+ )
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def compute_attention_score(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ pos_enc: torch.Tensor,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Attention score computation.
+ Args:
+ query: Transformed query tensor. (B, H, T_1, d_k)
+ key: Transformed key tensor. (B, H, T_2, d_k)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ left_context: Number of frames in left context.
+ Returns:
+ : Attention score. (B, H, T_1, T_2)
+ """
+ p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+ matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+ return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+ Returns:
+ q: Transformed query tensor. (B, H, T_1, d_k)
+ k: Transformed key tensor. (B, H, T_2, d_k)
+ v: Transformed value tensor. (B, H, T_2, d_k)
+ """
+ n_batch = query.size(0)
+
+ q = (
+ self.linear_q(query)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ k = (
+ self.linear_k(key)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+ v = (
+ self.linear_v(value)
+ .view(n_batch, -1, self.num_heads, self.d_k)
+ .transpose(1, 2)
+ )
+
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+ Args:
+ value: Transformed value. (B, H, T_2, d_k)
+ scores: Attention score. (B, H, T_1, T_2)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ Returns:
+ attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+ """
+ batch_size = scores.size(0)
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ if chunk_mask is not None:
+ mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+ scores = scores.masked_fill(mask, float("-inf"))
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+ attn_output = self.dropout(self.attn)
+ attn_output = torch.matmul(attn_output, value)
+
+ attn_output = self.linear_out(
+ attn_output.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.num_heads * self.d_k)
+ )
+
+ return attn_output
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ left_context: int = 0,
+ ) -> torch.Tensor:
+ """Compute scaled dot product attention with rel. positional encoding.
+ Args:
+ query: Query tensor. (B, T_1, size)
+ key: Key tensor. (B, T_2, size)
+ value: Value tensor. (B, T_2, size)
+ pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+ mask: Source mask. (B, T_2)
+ chunk_mask: Chunk mask. (T_1, T_1)
+ left_context: Number of frames in left context.
+ Returns:
+ : Output tensor. (B, T_1, H * d_k)
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+ return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
new file mode 100644
index 0000000..3eb8e08
--- /dev/null
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -0,0 +1,704 @@
+"""Search algorithms for Transducer models."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from funasr.models.joint_net.joint_network import JointNetwork
+
+
@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Total log-probability.
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    # Cumulative log-probability of this partial label sequence.
    score: float
    # Emitted label IDs; index 0 holds the blank/start token (ID 0).
    yseq: List[int]
    # Decoder hidden state matching yseq (None for stateless decoders).
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    # Neural LM state for shallow fusion (None when no LM is used).
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
+
+
@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec)
        lm_score: Log-probabilities of the LM for given label. (vocab_size)

    """

    # Cached decoder output for this hypothesis (filled on first scoring).
    dec_out: Optional[torch.Tensor] = None
    # Cached LM log-probabilities over the vocabulary (None without an LM).
    lm_score: Optional[torch.Tensor] = None
+
+
+class BeamSearchTransducer:
+ """Beam search implementation for Transducer.
+
+ Args:
+ decoder: Decoder module.
+ joint_network: Joint network module.
+ beam_size: Size of the beam.
+ lm: LM class.
+ lm_weight: LM weight for soft fusion.
+ search_type: Search algorithm to use during inference.
+ max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
+ u_max: Maximum expected target sequence length. (ALSD)
+ nstep: Number of maximum expansion steps at each time step. (mAES)
+ expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
+ expansion_beta:
+ Number of additional candidates for expanded hypotheses selection. (mAES)
+ score_norm: Normalize final scores by length.
+ nbest: Number of final hypothesis.
+ streaming: Whether to perform chunk-by-chunk beam search.
+
+ """
+
    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        beam_size: int,
        lm: Optional[torch.nn.Module] = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 3,
        u_max: int = 50,
        nstep: int = 2,
        expansion_gamma: float = 2.3,
        expansion_beta: int = 2,
        score_norm: bool = False,
        nbest: int = 1,
        streaming: bool = False,
    ) -> None:
        """Construct a BeamSearchTransducer object.

        See the class docstring for argument semantics. Validates the
        parameters of the selected search algorithm and binds it to
        `self.search_algorithm`.
        """
        super().__init__()

        self.decoder = decoder
        self.joint_network = joint_network

        self.vocab_size = decoder.vocab_size

        # The beam cannot be wider than the output distribution.
        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (
                beam_size,
                self.vocab_size,
            )
        )
        self.beam_size = beam_size

        # Select and configure the search algorithm.
        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
                max_sym_exp
            )
            self.max_sym_exp = max_sym_exp

            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            # ALSD requires the full encoder output up-front.
            assert not streaming, "ALSD is not available in streaming mode."

            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
            self.u_max = u_max

            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "maes":
            assert self.vocab_size >= beam_size + expansion_beta, (
                "beam_size (%d) + expansion_beta (%d) "
                " should be smaller than or equal to vocab size (%d)."
                % (beam_size, expansion_beta, self.vocab_size)
            )
            self.max_candidates = beam_size + expansion_beta

            self.nstep = nstep
            self.expansion_gamma = expansion_gamma

            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )

        self.use_lm = lm is not None

        if self.use_lm:
            # Only RNN LMs expose the stepwise scoring interface used below.
            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."

            # SOS is assumed to be the last vocabulary entry.
            self.sos = self.vocab_size - 1

            self.lm = lm
            self.lm_weight = lm_weight

        self.score_norm = score_norm
        self.nbest = nbest

        self.reset_inference_cache()
+
+ def __call__(
+ self,
+ enc_out: torch.Tensor,
+ is_final: bool = True,
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+ enc_out: Encoder output sequence. (T, D_enc)
+ is_final: Whether enc_out is the final chunk of data.
+
+ Returns:
+ nbest_hyps: N-best decoding results
+
+ """
+ self.decoder.set_device(enc_out.device)
+
+ hyps = self.search_algorithm(enc_out)
+
+ if is_final:
+ self.reset_inference_cache()
+
+ return self.sort_nbest(hyps)
+
+ self.search_cache = hyps
+
+ return hyps
+
+ def reset_inference_cache(self) -> None:
+ """Reset cache for decoder scoring and streaming."""
+ self.decoder.score_cache = {}
+ self.search_cache = None
+
+ def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+ """Sort in-place hypotheses by score or score given sequence length.
+
+ Args:
+ hyps: Hypothesis.
+
+ Return:
+ hyps: Sorted hypothesis.
+
+ """
+ if self.score_norm:
+ hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
+ else:
+ hyps.sort(key=lambda x: x.score, reverse=True)
+
+ return hyps[: self.nbest]
+
+ def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+ """Recombine hypotheses with same label ID sequence.
+
+ Args:
+ hyps: Hypotheses.
+
+ Returns:
+ final: Recombined hypotheses.
+
+ """
+ final = {}
+
+ for hyp in hyps:
+ str_yseq = "_".join(map(str, hyp.yseq))
+
+ if str_yseq in final:
+ final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
+ else:
+ final[str_yseq] = hyp
+
+ return [*final.values()]
+
+ def select_k_expansions(
+ self,
+ hyps: List[ExtendedHypothesis],
+ topk_idx: torch.Tensor,
+ topk_logp: torch.Tensor,
+ ) -> List[ExtendedHypothesis]:
+ """Return K hypotheses candidates for expansion from a list of hypothesis.
+
+ K candidates are selected according to the extended hypotheses probabilities
+ and a prune-by-value method. Where K is equal to beam_size + beta.
+
+ Args:
+ hyps: Hypotheses.
+ topk_idx: Indices of candidates hypothesis.
+ topk_logp: Log-probabilities of candidates hypothesis.
+
+ Returns:
+ k_expansions: Best K expansion hypotheses candidates.
+
+ """
+ k_expansions = []
+
+ for i, hyp in enumerate(hyps):
+ hyp_i = [
+ (int(k), hyp.score + float(v))
+ for k, v in zip(topk_idx[i], topk_logp[i])
+ ]
+ k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
+
+ k_expansions.append(
+ sorted(
+ filter(
+ lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
+ ),
+ key=lambda x: x[1],
+ reverse=True,
+ )
+ )
+
+ return k_expansions
+
+ def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
+ """Make batch of inputs with left padding for LM scoring.
+
+ Args:
+ hyps_seq: Hypothesis sequences.
+
+ Returns:
+ : Padded batch of sequences.
+
+ """
+ max_len = max([len(h) for h in hyps_seq])
+
+ return torch.LongTensor(
+ [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
+ device=self.decoder.device,
+ )
+
+ def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
+ """Beam search implementation without prefix search.
+
+ Modified from https://arxiv.org/pdf/1211.3711.pdf
+
+ Args:
+ enc_out: Encoder output sequence. (T, D)
+
+ Returns:
+ nbest_hyps: N-best hypothesis.
+
+ """
+ beam_k = min(self.beam_size, (self.vocab_size - 1))
+ max_t = len(enc_out)
+
+ if self.search_cache is not None:
+ kept_hyps = self.search_cache
+ else:
+ kept_hyps = [
+ Hypothesis(
+ score=0.0,
+ yseq=[0],
+ dec_state=self.decoder.init_state(1),
+ )
+ ]
+
+ for t in range(max_t):
+ hyps = kept_hyps
+ kept_hyps = []
+
+ while True:
+ max_hyp = max(hyps, key=lambda x: x.score)
+ hyps.remove(max_hyp)
+
+ label = torch.full(
+ (1, 1),
+ max_hyp.yseq[-1],
+ dtype=torch.long,
+ device=self.decoder.device,
+ )
+ dec_out, state = self.decoder.score(
+ label,
+ max_hyp.yseq,
+ max_hyp.dec_state,
+ )
+
+ logp = torch.log_softmax(
+ self.joint_network(enc_out[t : t + 1, :], dec_out),
+ dim=-1,
+ ).squeeze(0)
+ top_k = logp[1:].topk(beam_k, dim=-1)
+
+ kept_hyps.append(
+ Hypothesis(
+ score=(max_hyp.score + float(logp[0:1])),
+ yseq=max_hyp.yseq,
+ dec_state=max_hyp.dec_state,
+ lm_state=max_hyp.lm_state,
+ )
+ )
+
+ if self.use_lm:
+ lm_scores, lm_state = self.lm.score(
+ torch.LongTensor(
+ [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
+ ),
+ max_hyp.lm_state,
+ None,
+ )
+ else:
+ lm_state = max_hyp.lm_state
+
+ for logp, k in zip(*top_k):
+ score = max_hyp.score + float(logp)
+
+ if self.use_lm:
+ score += self.lm_weight * lm_scores[k + 1]
+
+ hyps.append(
+ Hypothesis(
+ score=score,
+ yseq=max_hyp.yseq + [int(k + 1)],
+ dec_state=state,
+ lm_state=lm_state,
+ )
+ )
+
+ hyps_max = float(max(hyps, key=lambda x: x.score).score)
+ kept_most_prob = sorted(
+ [hyp for hyp in kept_hyps if hyp.score > hyps_max],
+ key=lambda x: x.score,
+ )
+ if len(kept_most_prob) >= self.beam_size:
+ kept_hyps = kept_most_prob
+ break
+
+ return kept_hyps
+
    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,
    ) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequences. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
        final = []

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        # i indexes the alignment length n = t + u.
        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                # Drop hypotheses whose implied time index is past the input.
                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
                beam_dec_out, beam_state = self.decoder.batch_score(B_)

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([b.yseq for b in B_]),
                        [b.lm_state for b in B_],
                        None,
                    )

                # NOTE(review): this inner `i` shadows the alignment index
                # above; the outer value is not needed again within this
                # iteration, but the reuse is fragile — consider renaming.
                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    # Blank emitted on the last frame -> complete hypothesis.
                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                # Prune to the beam and merge duplicate label sequences.
                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
                B = self.recombine_hyps(B)

        if final:
            return final

        return B
+
    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        # Resume from the previous chunk when streaming.
        if self.search_cache is not None:
            B = self.search_cache
        else:
            B = [
                Hypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            if self.use_lm:
                B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            # Expand at most max_sym_exp symbols per time step.
            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state = self.decoder.batch_score(C)

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        # Blank path: keep the label sequence unchanged.
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Same label sequence reached again: merge scores.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([c.yseq for c in C]),
                            [c.lm_state for c in C],
                            None,
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

        return B
+
    def modified_adaptive_expansion_search(
        self,
        enc_out: torch.Tensor,
    ) -> List[ExtendedHypothesis]:
        """Modified version of Adaptive Expansion Search (mAES).

        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
        NSC (https://arxiv.org/abs/2201.05420).

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        # Resume from the previous chunk when streaming; otherwise build the
        # initial hypothesis with its cached decoder (and LM) outputs.
        if self.search_cache is not None:
            kept_hyps = self.search_cache
        else:
            init_tokens = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.init_state(1),
                )
            ]

            beam_dec_out, beam_state = self.decoder.batch_score(
                init_tokens,
            )

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                    [h.lm_state for h in init_tokens],
                    None,
                )

                lm_state = beam_lm_states[0]
                lm_score = beam_lm_scores[0]
            else:
                lm_state = None
                lm_score = None

            kept_hyps = [
                ExtendedHypothesis(
                    yseq=[0],
                    score=0.0,
                    dec_state=self.decoder.select_state(beam_state, 0),
                    dec_out=beam_dec_out[0],
                    lm_state=lm_state,
                    lm_score=lm_score,
                )
            ]

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            # list_b accumulates blank (non-emitting) expansions.
            list_b = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out,
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_score=hyp.lm_score,
                        )

                        if k == 0:
                            # Blank expansion: no new label is emitted.
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                            list_exp.append(new_hyp)

                if not list_exp:
                    # All expansions were blank: finalize this time step.
                    kept_hyps = sorted(
                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                    )[: self.beam_size]

                    break
                else:
                    beam_dec_out, beam_state = self.decoder.batch_score(
                        list_exp,
                    )

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                            [h.lm_state for h in list_exp],
                            None,
                        )

                    if n < (self.nstep - 1):
                        # More expansion steps to come: refresh cached outputs.
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        # Last expansion step: fold in the blank probability.
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out = beam_dec_out[i]
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_score = beam_lm_scores[i]

                        kept_hyps = sorted(
                            self.recombine_hyps(list_b + list_exp),
                            key=lambda x: x.score,
                            reverse=True,
                        )[: self.beam_size]

        return kept_hyps
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 92f9079..f430fcb 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
"""Common functions for ASR."""
+from typing import List, Optional, Tuple
+
import json
import logging
import sys
@@ -13,7 +15,10 @@
from itertools import groupby
import numpy as np
import six
+import torch
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
@@ -247,3 +252,148 @@
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
+
class ErrorCalculatorTransducer:
    """Calculate CER and WER for transducer models.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.
    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        # beam_size=1 keeps validation-time decoding cheap (greedy-like).
        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            : Sentence-level CER score (None when report_cer is False).
            : Sentence-level WER score (None when report_wer is False).
        """
        cer, wer = None, None

        batchsize = int(encoder_out.size(0))

        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        # Decode each utterance independently and drop the leading blank.
        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: List[List[int]], target: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character sequences.

        Args:
            pred: Prediction label ID sequences. (B, U)
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)
        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]

            # Map the space token to an actual space and strip blanks.
            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")

            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")

            char_pred.append(char_pred_i)
            char_target.append(char_target_i)

        return char_pred, char_target

    def calculate_cer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
        """Calculate sentence-level CER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level CER score.
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            # CER ignores word boundaries entirely.
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(
        self, char_pred: List[str], char_target: List[str]
    ) -> float:
        """Calculate sentence-level WER score.

        Args:
            char_pred: Prediction character sequences. (B, ?)
            char_target: Target character sequences. (B, ?)

        Returns:
            : Average sentence-level WER score
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            # Fix: the original source contained a mojibake ("鈻�") of the
            # sentencepiece word-boundary marker "▁" (U+2581); restore it so
            # subword boundaries split into words correctly.
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 79ca0b2..c347e24 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -407,7 +407,24 @@
return x + position_encoding
- def forward_chunk(self, x, cache=None):
+class StreamSinusoidalPositionEncoder(torch.nn.Module):
+ '''
+
+ '''
+ def __int__(self, d_model=80, dropout_rate=0.1):
+ pass
+
    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
        """Build sinusoidal position encodings for the given positions.

        Args:
            positions: Position indices. (B, T) — NOTE(review): the reshape of
                `inv_timescales` to [batch_size, -1] presumably assumes
                batch_size == 1; confirm against callers.
            depth: Encoding dimension (split into sin/cos halves).
            dtype: Output dtype.

        Returns:
            Encoding tensor, presumably (1, T, depth).
        """
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        # Standard transformer geometric timescale spacing over depth/2 freqs.
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        # Concatenate sin/cos halves along the feature axis.
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)
+
+ def forward(self, x, cache=None):
start_idx = 0
pad_left = 0
pad_right = 0
@@ -419,8 +436,83 @@
positions = torch.arange(1, timesteps+start_idx+1)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
- outputs = outputs.transpose(1,2)
+ outputs = outputs.transpose(1, 2)
outputs = F.pad(outputs, (pad_left, pad_right))
- outputs = outputs.transpose(1,2)
+ outputs = outputs.transpose(1, 2)
return outputs
-
+
class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding, usable for chunk-wise streaming.

    Args:
        size: Module size.
        max_len: Maximum input length.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
        """Construct a RelativePositionalEncoding object."""
        super().__init__()

        self.size = size

        # Cached encoding table; (re)built lazily by extend_pe.
        self.pe = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        # NOTE(review): relies on a module-level `_pre_hook` (defined
        # elsewhere in this file) for checkpoint compatibility.
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Reset positional encoding table if the cached one is too short.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.
        """
        time1 = x.size(1) + left_context

        if self.pe is not None:
            # Cached table covers the required window: only align dtype/device.
            if self.pe.size(1) >= time1 * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
                return

        pe_positive = torch.zeros(time1, self.size)
        pe_negative = torch.zeros(time1, self.size)

        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
        )

        # Positive relative offsets, flipped so index 0 is the farthest one.
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)

        # Negative relative offsets; offset 0 dropped to avoid duplication.
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        pe_negative = pe_negative[1:].unsqueeze(0)

        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            dtype=x.dtype, device=x.device
        )

    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute positional encoding.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        Returns:
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
        """
        self.extend_pe(x, left_context=left_context)

        time1 = x.size(1) + left_context

        # Slice the window of relative offsets around the table center.
        pos_enc = self.pe[
            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
        ]
        pos_enc = self.dropout(pos_enc)

        return pos_enc
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..5d4fe1c 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
-from typing import Dict
+from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -506,3 +506,196 @@
}
return activation_funcs[act]()
+
class TooShortUttError(Exception):
    """Error raised when an utterance is too short for subsampling.

    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Store the offending sizes and forward the message to Exception."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit
+
+
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.

    """
    # factor -> (minimum acceptable size, reported limit)
    requirements = {2: (3, 7), 4: (7, 7), 6: (11, 11)}

    entry = requirements.get(sub_factor)
    if entry is not None and size < entry[0]:
        return True, entry[1]

    return False, -1
+
+
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    """
    # Output width after the first /2 convolution.
    half = (input_size - 1) // 2

    if sub_factor == 2:
        return 3, 1, half - 2
    if sub_factor == 4:
        return 3, 2, (half - 1) // 2
    if sub_factor == 6:
        return 5, 3, (half - 2) // 3

    raise ValueError(
        "subsampling_factor parameter should be set to either 2, 4 or 6."
    )
+
+
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)

    """
    allowed = torch.zeros(size, size, device=device, dtype=torch.bool)

    for row in range(size):
        chunk_idx = row // chunk_size

        # Full left context when left_chunk_size <= 0.
        if left_chunk_size <= 0:
            begin = 0
        else:
            begin = max((chunk_idx - left_chunk_size) * chunk_size, 0)
        end = min((chunk_idx + 1) * chunk_size, size)

        allowed[row, begin:end] = True

    # True marks positions that must be masked out.
    return ~allowed
+
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths. (B, max_len); True marks padding.

    """
    max_len = lengths.max()

    # Broadcast a step index row against per-sequence lengths.
    steps = torch.arange(max_len).to(lengths)

    return steps.unsqueeze(0) >= lengths.unsqueeze(1)
+
+
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """

    def _pad_batch(seqs: List[torch.Tensor], padding_value: int = 0):
        """Right-pad a list of label sequences into one batch tensor."""
        n_seqs = len(seqs)
        padded = (
            seqs[0]
            .new(n_seqs, max(s.size(0) for s in seqs), *seqs[0].size()[1:])
            .fill_(padding_value)
        )
        for idx, seq in enumerate(seqs):
            padded[idx, : seq.size(0)] = seq
        return padded

    device = labels.device

    # Strip padding symbols from every row.
    label_seqs = [row[row != ignore_id] for row in labels]
    blank = labels[0].new([blank_id])

    # Decoder input is the label sequence shifted right by one blank.
    decoder_in = _pad_batch(
        [torch.cat([blank, seq], dim=0) for seq in label_seqs], blank_id
    ).to(device)

    target = _pad_batch(label_seqs, blank_id).type(torch.int32).to(device)

    t_len = torch.IntTensor([int(length) for length in encoder_out_lens]).to(device)
    u_len = torch.IntTensor([seq.size(0) for seq in label_seqs]).to(device)

    return decoder_in, target, t_len, u_len
+
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    gap = pad_len - t.size(dim)
    if gap == 0:
        # Already the requested length; return the input unchanged.
        return t

    filler_shape = list(t.shape)
    filler_shape[dim] = gap
    filler = torch.zeros(*filler_shape, dtype=t.dtype, device=t.device)
    return torch.cat([t, filler], dim=dim)
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index a3d2676..2b2dac8 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -6,6 +6,8 @@
"""Repeat the same layer definition."""
+from typing import Dict, List, Optional
+
import torch
@@ -31,3 +33,92 @@
"""
return MultiSequential(*[fn(n) for n in range(N)])
+
+
class MultiBlocks(torch.nn.Module):
    """Container running a stack of encoder blocks and a final normalization.

    Args:
        block_list: Individual blocks of the encoder architecture.
        output_size: Architecture output size.
        norm_class: Normalization module class applied after the last block.
    """

    def __init__(
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()

        self.blocks = torch.nn.ModuleList(block_list)
        self.norm_blocks = norm_class(output_size)
        self.num_blocks = len(block_list)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset the streaming cache of every block.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensors.
        """
        for blk in self.blocks:
            blk.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture (offline pass).

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences.
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Output sequences. (B, T, D_block_N)
        """
        for blk in self.blocks:
            x, mask, pos_enc = blk(x, pos_enc, mask, chunk_mask=chunk_mask)

        return self.norm_blocks(x)

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 0,
        left_context: int = 0,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Forward one streaming chunk through every block.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames per chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: MultiBlocks output sequences. (B, T, D_block_N)
        """
        for blk in self.blocks:
            x, pos_enc = blk.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )

        return self.norm_blocks(x)
diff --git a/funasr/modules/streaming_utils/chunk_utilis.py b/funasr/modules/streaming_utils/chunk_utilis.py
index ea37c68..ed8b31e 100644
--- a/funasr/modules/streaming_utils/chunk_utilis.py
+++ b/funasr/modules/streaming_utils/chunk_utilis.py
@@ -11,7 +11,7 @@
class overlap_chunk():
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index d492ccf..623be65 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -11,6 +11,10 @@
from funasr.modules.embedding import PositionalEncoding
import logging
from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
+
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
@@ -407,3 +411,201 @@
var_dict_tf[name_tf].shape))
return var_dict_torch_update
+class StreamingConvInput(torch.nn.Module):
+ """Streaming ConvInput module definition.
+ Args:
+ input_size: Input size.
+ conv_size: Convolution size.
+ subsampling_factor: Subsampling factor.
+ vgg_like: Whether to use a VGG-like network.
+ output_size: Block output dimension.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ conv_size: Union[int, Tuple],
+ subsampling_factor: int = 4,
+ vgg_like: bool = True,
+ output_size: Optional[int] = None,
+ ) -> None:
+ """Construct a ConvInput object."""
+ super().__init__()
+ if vgg_like:
+ if subsampling_factor == 1:
+ conv_size1, conv_size2 = conv_size
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((1, 2)),
+ torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((1, 2)),
+ )
+
+ output_proj = conv_size2 * ((input_size // 2) // 2)
+
+ self.subsampling_factor = 1
+
+ self.stride_1 = 1
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ else:
+ conv_size1, conv_size2 = conv_size
+
+ kernel_1 = int(subsampling_factor / 2)
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((kernel_1, 2)),
+ torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+ torch.nn.ReLU(),
+ torch.nn.MaxPool2d((2, 2)),
+ )
+
+ output_proj = conv_size2 * ((input_size // 2) // 2)
+
+ self.subsampling_factor = subsampling_factor
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ self.stride_1 = kernel_1
+
+ else:
+ if subsampling_factor == 1:
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+ torch.nn.ReLU(),
+ )
+
+ output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
+
+ self.subsampling_factor = subsampling_factor
+ self.kernel_2 = 3
+ self.stride_2 = 1
+
+ self.create_new_mask = self.create_new_conv2d_mask
+
+ else:
+ kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
+ subsampling_factor,
+ input_size,
+ )
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size, 3, 2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+ torch.nn.ReLU(),
+ )
+
+ output_proj = conv_size * conv_2_output_size
+
+ self.subsampling_factor = subsampling_factor
+ self.kernel_2 = kernel_2
+ self.stride_2 = stride_2
+
+ self.create_new_mask = self.create_new_conv2d_mask
+
+ self.vgg_like = vgg_like
+ self.min_frame_length = 7
+
+ if output_size is not None:
+ self.output = torch.nn.Linear(output_proj, output_size)
+ self.output_size = output_size
+ else:
+ self.output = None
+ self.output_size = output_proj
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: ConvInput input sequences. (B, T, D_feats)
+ mask: Mask of input sequences. (B, 1, T)
+ Returns:
+ x: ConvInput output sequences. (B, sub(T), D_out)
+ mask: Mask of output sequences. (B, 1, sub(T))
+ """
+ if mask is not None:
+ mask = self.create_new_mask(mask)
+ olens = max(mask.eq(0).sum(1))
+
+ b, t, f = x.size()
+ x = x.unsqueeze(1) # (b, 1, t, f)
+
+ if chunk_size is not None:
+ max_input_length = int(
+ chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
+ )
+ x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+ x = list(x)
+ x = torch.stack(x, dim=0)
+ N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
+ x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+ x = self.conv(x)
+
+ _, c, _, f = x.size()
+ if chunk_size is not None:
+ x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
+ else:
+ x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+ if self.output is not None:
+ x = self.output(x)
+
+ return x, mask[:,:olens][:,:x.size(1)]
+
+ def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Create a new mask for VGG output sequences.
+ Args:
+ mask: Mask of input sequences. (B, T)
+ Returns:
+ mask: Mask of output sequences. (B, sub(T))
+ """
+ if self.subsampling_factor > 1:
+ vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
+ mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
+
+ vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
+ mask = mask[:, :vgg2_t_len][:, ::2]
+ else:
+ mask = mask
+
+ return mask
+
+ def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Create new conformer mask for Conv2d output sequences.
+ Args:
+ mask: Mask of input sequences. (B, T)
+ Returns:
+ mask: Mask of output sequences. (B, sub(T))
+ """
+ if self.subsampling_factor > 1:
+ return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+ else:
+ return mask
+
+ def get_size_before_subsampling(self, size: int) -> int:
+ """Return the original size before subsampling for a given size.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+ : Number of frames before subsampling.
+ """
+ return size * self.subsampling_factor
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index 82347be..23e618c 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -1,6 +1,9 @@
-## paraformer grpc onnx server in c++
+# Using funasr with grpc-cpp
-#### Step 1. Build ../onnxruntime as it's document
+## For the Server
+
+### Build [onnxruntime](./onnxruntime_cpp.md) as its documentation describes
+
```
#put onnx-lib & onnx-asr-model into /path/to/asrmodel(eg: /data/asrmodel)
ls /data/asrmodel/
@@ -10,7 +13,7 @@
```
-#### Step 2. Compile and install grpc v1.52.0 in case of grpc bugs
+### Compile and install grpc v1.52.0 (to avoid known grpc bugs)
```
export GRPC_INSTALL_DIR=/data/soft/grpc
export PKG_CONFIG_PATH=$GRPC_INSTALL_DIR/lib/pkgconfig
@@ -35,84 +38,149 @@
source ~/.bashrc
```
-#### Step 3. Compile and start grpc onnx paraformer server
+### Compile and start grpc onnx paraformer server
```
# set -DONNXRUNTIME_DIR=/path/to/asrmodel/onnxruntime-linux-x64-1.14.0
./rebuild.sh
```
-#### Step 4. Start grpc paraformer server
+### Start grpc paraformer server
```
Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
```
-#### Step 5. Start grpc python paraformer client on PC with MIC
-```
-cd ../python/grpc
-python grpc_main_client_mic.py --host $server_ip --port 10108
+## For the Client
+
+### Install the requirements as in [grpc-python](./docs/grpc_python.md)
+
+```shell
+git clone https://github.com/alibaba/FunASR.git && cd FunASR
+cd funasr/runtime/python/grpc
+pip install -r requirements_client.txt
```
-The `grpc_main_client_mic.py` follows the [original design] (https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data with chunks. If you want to send audio_data in one request, here is an example:
+### Generate protobuf file
+Run this on the server; the two generated pb files are used by both the server and the client.
+```shell
+# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
+# regenerate it only when you make changes to ./proto/paraformer.proto file.
+python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
-# go to ../python/grpc to find this package
-import paraformer_pb2
-
-class RecognizeStub:
- def __init__(self, channel):
- self.Recognize = channel.stream_stream(
- '/paraformer.ASR/Recognize',
- request_serializer=paraformer_pb2.Request.SerializeToString,
- response_deserializer=paraformer_pb2.Response.FromString,
- )
-
-
-async def send(channel, data, speaking, isEnd):
- stub = RecognizeStub(channel)
- req = paraformer_pb2.Request()
- if data:
- req.audio_data = data
- req.user = 'zz'
- req.language = 'zh-CN'
- req.speaking = speaking
- req.isEnd = isEnd
- q = queue.SimpleQueue()
- q.put(req)
- return stub.Recognize(iter(q.get, None))
-
-# send the audio data once
-async def grpc_rec(data, grpc_uri):
- with grpc.insecure_channel(grpc_uri) as channel:
- b = time.time()
- response = await send(channel, data, False, False)
- resp = response.next()
- text = ''
- if 'decoding' == resp.action:
- resp = response.next()
- if 'finish' == resp.action:
- text = json.loads(resp.sentence)['text']
- response = await send(channel, None, False, True)
- return {
- 'text': text,
- 'time': time.time() - b,
- }
-
-async def test():
- # fc = FunAsrGrpcClient('127.0.0.1', 9900)
- # t = await fc.rec(wav.tobytes())
- # print(t)
- wav, _ = sf.read('z-10s.wav', dtype='int16')
- uri = '127.0.0.1:9900'
- res = await grpc_rec(wav.tobytes(), uri)
- print(res)
-
-
-if __name__ == '__main__':
- asyncio.run(test())
-
+### Start grpc client
```
+# Start client.
+python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
+```
+
+[//]: # (```)
+
+[//]: # (# go to ../python/grpc to find this package)
+
+[//]: # (import paraformer_pb2)
+
+[//]: # ()
+[//]: # ()
+[//]: # (class RecognizeStub:)
+
+[//]: # ( def __init__(self, channel):)
+
+[//]: # ( self.Recognize = channel.stream_stream()
+
+[//]: # ( '/paraformer.ASR/Recognize',)
+
+[//]: # ( request_serializer=paraformer_pb2.Request.SerializeToString,)
+
+[//]: # ( response_deserializer=paraformer_pb2.Response.FromString,)
+
+[//]: # ( ))
+
+[//]: # ()
+[//]: # ()
+[//]: # (async def send(channel, data, speaking, isEnd):)
+
+[//]: # ( stub = RecognizeStub(channel))
+
+[//]: # ( req = paraformer_pb2.Request())
+
+[//]: # ( if data:)
+
+[//]: # ( req.audio_data = data)
+
+[//]: # ( req.user = 'zz')
+
+[//]: # ( req.language = 'zh-CN')
+
+[//]: # ( req.speaking = speaking)
+
+[//]: # ( req.isEnd = isEnd)
+
+[//]: # ( q = queue.SimpleQueue())
+
+[//]: # ( q.put(req))
+
+[//]: # ( return stub.Recognize(iter(q.get, None)))
+
+[//]: # ()
+[//]: # (# send the audio data once)
+
+[//]: # (async def grpc_rec(data, grpc_uri):)
+
+[//]: # ( with grpc.insecure_channel(grpc_uri) as channel:)
+
+[//]: # ( b = time.time())
+
+[//]: # ( response = await send(channel, data, False, False))
+
+[//]: # ( resp = response.next())
+
+[//]: # ( text = '')
+
+[//]: # ( if 'decoding' == resp.action:)
+
+[//]: # ( resp = response.next())
+
+[//]: # ( if 'finish' == resp.action:)
+
+[//]: # ( text = json.loads(resp.sentence)['text'])
+
+[//]: # ( response = await send(channel, None, False, True))
+
+[//]: # ( return {)
+
+[//]: # ( 'text': text,)
+
+[//]: # ( 'time': time.time() - b,)
+
+[//]: # ( })
+
+[//]: # ()
+[//]: # (async def test():)
+
+[//]: # ( # fc = FunAsrGrpcClient('127.0.0.1', 9900))
+
+[//]: # ( # t = await fc.rec(wav.tobytes()))
+
+[//]: # ( # print(t))
+
+[//]: # ( wav, _ = sf.read('z-10s.wav', dtype='int16'))
+
+[//]: # ( uri = '127.0.0.1:9900')
+
+[//]: # ( res = await grpc_rec(wav.tobytes(), uri))
+
+[//]: # ( print(res))
+
+[//]: # ()
+[//]: # ()
+[//]: # (if __name__ == '__main__':)
+
+[//]: # ( asyncio.run(test()))
+
+[//]: # ()
+[//]: # (```)
## Acknowledge
diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer_server.cc
index 2bfd3e5..2893d4c 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer_server.cc
@@ -128,7 +128,7 @@
stream->Write(res);
}
else {
- FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
+ FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 4ffe0f3..6feef92 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -2,24 +2,27 @@
project(FunASRonnx)
-set(CMAKE_CXX_STANDARD 11)
+# set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+include(TestBigEndian)
+test_big_endian(BIG_ENDIAN)
+if(BIG_ENDIAN)
+ message("Big endian system")
+else()
+ message("Little endian system")
+endif()
+
# for onnxruntime
-
IF(WIN32)
-
-
if(CMAKE_CL_64)
link_directories(${ONNXRUNTIME_DIR}\\lib)
else()
add_definitions(-D_WIN_X86)
endif()
ELSE()
-
-
-link_directories(${ONNXRUNTIME_DIR}/lib)
-
+ link_directories(${ONNXRUNTIME_DIR}/lib)
endif()
add_subdirectory("./third_party/yaml-cpp")
diff --git a/funasr/runtime/onnxruntime/include/Audio.h b/funasr/runtime/onnxruntime/include/Audio.h
index da5e82c..ec49a9f 100644
--- a/funasr/runtime/onnxruntime/include/Audio.h
+++ b/funasr/runtime/onnxruntime/include/Audio.h
@@ -6,6 +6,13 @@
#include <queue>
#include <stdint.h>
+#ifndef model_sample_rate
+#define model_sample_rate 16000
+#endif
+#ifndef WAV_HEADER_SIZE
+#define WAV_HEADER_SIZE 44
+#endif
+
using namespace std;
class AudioFrame {
@@ -32,7 +39,6 @@
int16_t *speech_buff;
int speech_len;
int speech_align_len;
- int16_t sample_rate;
int offset;
float align_size;
int data_type;
@@ -43,10 +49,11 @@
Audio(int data_type, int size);
~Audio();
void disp();
- bool loadwav(const char* filename);
- bool loadwav(const char* buf, int nLen);
- bool loadpcmwav(const char* buf, int nFileLen);
- bool loadpcmwav(const char* filename);
+ bool loadwav(const char* filename, int32_t* sampling_rate);
+ void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
+ bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
+ bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
+ bool loadpcmwav(const char* filename, int32_t* sampling_rate);
int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag);
void padding();
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 6e81fa9..9bc37e7 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -55,9 +55,9 @@
// if not give a fnCallback ,it should be NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index b234e16..f7be2e0 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -1,82 +1,77 @@
+# ONNXRuntime-cpp
-## 蹇�熶娇鐢�
+## Export the model
+### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
-### Windows
-
- 瀹夎Vs2022 鎵撳紑cpp_onnx鐩綍涓嬬殑cmake宸ョ▼锛岀洿鎺� build鍗冲彲銆� 鏈粨搴撳凡缁忓噯澶囧ソ鎵�鏈夌浉鍏充緷璧栧簱銆�
-
- Windows涓嬪凡缁忛缃甪ftw3鍙妎nnxruntime搴�
-
-### Linux
-See the bottom of this page: Building Guidance
-
-### 杩愯绋嬪簭
-
-tester /path/to/models_dir /path/to/wave_file quantize(true or false)
-
-渚嬪锛� tester /data/models /data/test.wav false
-
-/data/models 闇�瑕佸寘鎷涓嬩笁涓枃浠�: config.yaml, am.mvn, model.onnx(or model_quant.onnx)
-
-## 鏀寔骞冲彴
-- Windows
-- Linux/Unix
-
-## 渚濊禆
-- fftw3
-- openblas
-- onnxruntime
-
-## 瀵煎嚭onnx鏍煎紡妯″瀷鏂囦欢
-瀹夎 modelscope涓嶧unASR锛屼緷璧栵細torch锛宼orchaudio锛屽畨瑁呰繃绋媅璇︾粏鍙傝�冩枃妗(https://github.com/alibaba-damo-academy/FunASR/wiki)
```shell
-pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install --editable ./
+pip3 install torch torchaudio
+pip install -U modelscope
+pip install -U funasr
```
-瀵煎嚭onnx妯″瀷锛孾璇﹁](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)锛屽弬鑰冪ず渚嬶紝浠巑odelscope涓ā鍨嬪鍑猴細
+
+### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
-## Building Guidance for Linux/Unix
+## Building for Linux/Unix
-```
-git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
-mkdir build
-cd build
+### Download onnxruntime
+```shell
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
-# ls
-# onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz
-
-#install fftw3-dev
-ubuntu: apt install libfftw3-dev
-centos: yum install fftw fftw-devel
-
-#install openblas
-bash ./third_party/install_openblas.sh
-
-# build
- cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
- make
-
- # then in the subfolder tester of current direcotry, you will see a program, tester
-
-````
-
-### The structure of a qualified onnxruntime package.
-```
-onnxruntime_xxx
-鈹溾攢鈹�鈹�include
-鈹斺攢鈹�鈹�lib
```
-## 娉ㄦ剰
-鏈▼搴忓彧鏀寔 閲囨牱鐜�16000hz, 浣嶆繁16bit鐨� **鍗曞0閬�** 闊抽銆�
+### Install fftw3
+```shell
+sudo apt install libfftw3-dev #ubuntu
+# sudo yum install fftw fftw-devel #centos
+```
+
+### Install openblas
+```shell
+sudo apt-get install libopenblas-dev #ubuntu
+# sudo yum -y install openblas-devel #centos
+```
+
+### Build runtime
+```shell
+git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
+make
+```
+
+[//]: # (### The structure of a qualified onnxruntime package.)
+
+[//]: # (```)
+
+[//]: # (onnxruntime_xxx)
+
+[//]: # (鈹溾攢鈹�鈹�include)
+
+[//]: # (鈹斺攢鈹�鈹�lib)
+
+[//]: # (```)
+
+## Building for Windows
+
+Refer to the win/ directory.
+
+
+## Run the demo
+
+```shell
+tester /path/models_dir /path/wave_file quantize(true or false)
+```
+
+The structure of /path/models_dir
+```
+config.yaml, am.mvn, model.onnx(or model_quant.onnx)
+```
## Acknowledge
diff --git a/funasr/runtime/onnxruntime/src/Audio.cpp b/funasr/runtime/onnxruntime/src/Audio.cpp
index bce3a90..38b6de8 100644
--- a/funasr/runtime/onnxruntime/src/Audio.cpp
+++ b/funasr/runtime/onnxruntime/src/Audio.cpp
@@ -3,10 +3,95 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
+#include <fstream>
+#include <assert.h>
#include "Audio.h"
+#include "precomp.h"
using namespace std;
+
+// see http://soundfile.sapp.org/doc/WaveFormat/
+// Note: We assume little endian here
+struct WaveHeader {
+ bool Validate() const {
+ // F F I R
+ if (chunk_id != 0x46464952) {
+ printf("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
+ return false;
+ }
+ // E V A W
+ if (format != 0x45564157) {
+ printf("Expected format WAVE. Given: 0x%08x\n", format);
+ return false;
+ }
+
+ if (subchunk1_id != 0x20746d66) {
+ printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
+ subchunk1_id);
+ return false;
+ }
+
+ if (subchunk1_size != 16) { // 16 for PCM
+ printf("Expected subchunk1_size 16. Given: %d\n",
+ subchunk1_size);
+ return false;
+ }
+
+ if (audio_format != 1) { // 1 for PCM
+ printf("Expected audio_format 1. Given: %d\n", audio_format);
+ return false;
+ }
+
+ if (num_channels != 1) { // we support only single channel for now
+ printf("Expected single channel. Given: %d\n", num_channels);
+ return false;
+ }
+ if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
+ return false;
+ }
+
+ if (block_align != (num_channels * bits_per_sample / 8)) {
+ return false;
+ }
+
+ if (bits_per_sample != 16) { // we support only 16 bits per sample
+ printf("Expected bits_per_sample 16. Given: %d\n",
+ bits_per_sample);
+ return false;
+ }
+ return true;
+ }
+
+ // See https://en.wikipedia.org/wiki/WAV#Metadata and
+ // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
+ void SeekToDataChunk(std::istream &is) {
+ // a t a d
+ while (is && subchunk2_id != 0x61746164) {
+ // const char *p = reinterpret_cast<const char *>(&subchunk2_id);
+ // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
+ // p[1], p[2], p[3], subchunk2_size);
+ is.seekg(subchunk2_size, std::istream::cur);
+ is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
+ is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
+ }
+ }
+
+ int32_t chunk_id;
+ int32_t chunk_size;
+ int32_t format;
+ int32_t subchunk1_id;
+ int32_t subchunk1_size;
+ int16_t audio_format;
+ int16_t num_channels;
+ int32_t sample_rate;
+ int32_t byte_rate;
+ int16_t block_align;
+ int16_t bits_per_sample;
+ int32_t subchunk2_id; // a tag of this chunk
+ int32_t subchunk2_size; // size of subchunk2
+};
+static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, "");
class AudioWindow {
private:
@@ -56,7 +141,7 @@
float frame_length = 400;
float frame_shift = 160;
float num_new_samples =
- ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length;
+ ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
end = start + num_new_samples;
len = (int)num_new_samples;
@@ -111,62 +196,95 @@
void Audio::disp()
{
- printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000,
+ printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate,
speech_len);
}
float Audio::get_time_len()
{
- return (float)speech_len / 16000;
- //speech_len);
+ return (float)speech_len / model_sample_rate;
}
-bool Audio::loadwav(const char *filename)
+void Audio::wavResample(int32_t sampling_rate, const float *waveform,
+ int32_t n)
{
+ printf(
+ "Creating a resampler:\n"
+ " in_sample_rate: %d\n"
+ " output_sample_rate: %d\n",
+ sampling_rate, static_cast<int32_t>(model_sample_rate));
+ float min_freq =
+ std::min<int32_t>(sampling_rate, model_sample_rate);
+ float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+ int32_t lowpass_filter_width = 6;
+ //FIXME
+ //auto resampler = new LinearResample(
+ // sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
+ auto resampler = std::make_unique<LinearResample>(
+ sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
+ std::vector<float> samples;
+ resampler->Resample(waveform, n, true, &samples);
+ //reset speech_data
+ speech_len = samples.size();
+ if (speech_data != NULL) {
+ free(speech_data);
+ }
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
+ copy(samples.begin(), samples.end(), speech_data);
+}
+
+bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
+{
+ WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
-
+
offset = 0;
-
- FILE *fp;
- fp = fopen(filename, "rb");
- if (fp == nullptr)
+ std::ifstream is(filename, std::ifstream::binary);
+ is.read(reinterpret_cast<char *>(&header), sizeof(header));
+ if(!is){
+ fprintf(stderr, "Failed to read %s\n", filename);
return false;
- fseek(fp, 0, SEEK_END); /*瀹氫綅鍒版枃浠舵湯灏�*/
- uint32_t nFileLen = ftell(fp); /*寰楀埌鏂囦欢澶у皬*/
- fseek(fp, 44, SEEK_SET); /*璺宠繃wav鏂囦欢澶�*/
-
- speech_len = (nFileLen - 44) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len);
+ }
+
+ *sampling_rate = header.sample_rate;
+ // header.subchunk2_size contains the number of bytes in the data.
+ // As we assume each sample contains two bytes, so it is divided by 2 here
+ speech_len = header.subchunk2_size / 2;
+ speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
- int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
- fclose(fp);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+ is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size);
+ if (!is) {
+ fprintf(stderr, "Failed to read %s\n", filename);
+ return false;
+ }
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
-
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
-
return true;
}
@@ -174,57 +292,54 @@
return false;
}
-
-bool Audio::loadwav(const char* buf, int nFileLen)
+bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
{
-
-
-
+ WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
-
offset = 0;
- size_t nOffset = 0;
+ std::memcpy(&header, buf, sizeof(header));
-#define WAV_HEADER_SIZE 44
-
- speech_len = (nFileLen - WAV_HEADER_SIZE) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ *sampling_rate = header.sample_rate;
+ speech_len = header.subchunk2_size / 2;
+ speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t));
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
+ }
+ AudioFrame* frame = new AudioFrame(speech_len);
+ frame_queue.push(frame);
return true;
}
else
return false;
-
}
-
-bool Audio::loadpcmwav(const char* buf, int nBufLen)
+bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@@ -234,32 +349,28 @@
}
offset = 0;
- size_t nOffset = 0;
-
-
-
speech_len = nBufLen / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
-
-
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
@@ -269,13 +380,10 @@
}
else
return false;
-
-
}
-bool Audio::loadpcmwav(const char* filename)
+bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
{
-
if (speech_data != NULL) {
free(speech_data);
}
@@ -293,34 +401,31 @@
fseek(fp, 0, SEEK_SET);
speech_len = (nFileLen) / 2;
- speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
- speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
- memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);
- speech_data = (float*)malloc(sizeof(float) * speech_align_len);
- memset(speech_data, 0, sizeof(float) * speech_align_len);
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
-
-
- int i;
float scale = 1;
-
if (data_type == 1) {
scale = 32768;
}
-
- for (i = 0; i < speech_len; i++) {
+ for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
+ //resample
+ if(*sampling_rate != model_sample_rate){
+ wavResample(*sampling_rate, speech_data, speech_len);
+ }
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
-
return true;
}
@@ -328,7 +433,6 @@
return false;
}
-
int Audio::fetch_chunck(float *&dout, int len)
{
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index c07aac5..d41fcd0 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,5 +1,6 @@
file(GLOB files1 "*.cpp")
+file(GLOB files2 "*.cc")
file(GLOB files4 "paraformer/*.cpp")
set(files ${files1} ${files2} ${files3} ${files4})
diff --git a/funasr/runtime/onnxruntime/src/Vocab.cpp b/funasr/runtime/onnxruntime/src/Vocab.cpp
index af6312b..b54a6c6 100644
--- a/funasr/runtime/onnxruntime/src/Vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/Vocab.cpp
@@ -13,21 +13,6 @@
{
ifstream in(filename);
loadVocabFromYaml(filename);
-
- /*
- string line;
- if (in) // 鏈夎鏂囦欢
- {
- while (getline(in, line)) // line涓笉鍖呮嫭姣忚鐨勬崲琛岀
- {
- vocab.push_back(line);
- }
- }
- else{
- printf("Cannot load vocab from: %s, there must be file vocab.txt", filename);
- exit(-1);
- }
- */
}
Vocab::~Vocab()
{
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index 0d77d20..a2ecf10 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -17,8 +17,9 @@
if (!pRecogObj)
return nullptr;
+ int32_t sampling_rate = -1;
Audio audio(1);
- if (!audio.loadwav(szBuf, nLen))
+ if (!audio.loadwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();
@@ -41,14 +42,14 @@
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
Audio audio(1);
- if (!audio.loadpcmwav(szBuf, nLen))
+ if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();
@@ -71,14 +72,14 @@
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
Audio audio(1);
- if (!audio.loadpcmwav(szFileName))
+ if (!audio.loadpcmwav(szFileName, &sampling_rate))
return nullptr;
//audio.split();
@@ -106,9 +107,10 @@
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
-
+
+ int32_t sampling_rate = -1;
Audio audio(1);
- if(!audio.loadwav(szWavfile))
+ if(!audio.loadwav(szWavfile, &sampling_rate))
return nullptr;
//audio.split();
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 678cdf6..695e0f7 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
@@ -29,20 +29,20 @@
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
- m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
- m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
string strName;
- getInputName(m_session, strName);
+ getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
- getInputName(m_session, strName,1);
+ getInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);
- getOutputName(m_session, strName);
+ getOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
- getOutputName(m_session, strName,1);
+ getOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@@ -55,11 +55,6 @@
ModelImp::~ModelImp()
{
- if (m_session)
- {
- delete m_session;
- m_session = nullptr;
- }
if(vocab)
delete vocab;
fftwf_free(fft_input);
@@ -70,7 +65,6 @@
void ModelImp::reset()
{
- printf("Not Imp!!!!!!\n");
}
void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -173,6 +167,12 @@
apply_cmvn(in);
Ort::RunOptions run_option;
+#ifdef _WIN_X86
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+
std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
in->buff,
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index e763be2..8946ae1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -24,15 +24,9 @@
string greedy_search( float* in, int nLen);
-#ifdef _WIN_X86
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-#else
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-#endif
-
- Ort::Session* m_session = nullptr;
- Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
- Ort::SessionOptions sessionOptions = Ort::SessionOptions();
+ std::unique_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions sessionOptions;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 678a3e4..3aeed14 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -44,6 +44,7 @@
#include "FeatureQueue.h"
#include "SpeechWrap.h"
#include <Audio.h>
+#include "resample.h"
#include "Model.h"
#include "paraformer_onnx.h"
#include "libfunasrapi.h"
diff --git a/funasr/runtime/onnxruntime/src/resample.cc b/funasr/runtime/onnxruntime/src/resample.cc
new file mode 100644
index 0000000..0238752
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/resample.cc
@@ -0,0 +1,305 @@
+/**
+ * Copyright 2013 Pegah Ghahremani
+ * 2014 IMSL, PKU-HKUST (author: Wei Shi)
+ * 2014 Yanqing Sun, Junjie Wang
+ * 2014 Johns Hopkins University (author: Daniel Povey)
+ * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// this file is copied and modified from
+// kaldi/src/feat/resample.cc
+
+#include "resample.h"
+
+#include <assert.h>
+#include <math.h>
+#include <stdio.h>
+
+#include <cstdlib>
+#include <type_traits>
+
+#ifndef M_2PI
+#define M_2PI 6.283185307179586476925286766559005
+#endif
+
+#ifndef M_PI
+#define M_PI 3.1415926535897932384626433832795
+#endif
+
+template <class I>
+I Gcd(I m, I n) {
+ // this function is copied from kaldi/src/base/kaldi-math.h
+ if (m == 0 || n == 0) {
+ if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
+ fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
+ exit(-1);
+ }
+ return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
+ // return absolute value of whichever is nonzero
+ }
+ // could use compile-time assertion
+ // but involves messing with complex template stuff.
+ static_assert(std::is_integral<I>::value, "");
+ while (1) {
+ m %= n;
+ if (m == 0) return (n > 0 ? n : -n);
+ n %= m;
+ if (n == 0) return (m > 0 ? m : -m);
+ }
+}
+
+/// Returns the least common multiple of two integers. Will
+/// crash unless the inputs are positive.
+template <class I>
+I Lcm(I m, I n) {
+ // This function is copied from kaldi/src/base/kaldi-math.h
+ assert(m > 0 && n > 0);
+ I gcd = Gcd(m, n);
+ return gcd * (m / gcd) * (n / gcd);
+}
+
+static float DotProduct(const float *a, const float *b, int32_t n) {
+ float sum = 0;
+ for (int32_t i = 0; i != n; ++i) {
+ sum += a[i] * b[i];
+ }
+ return sum;
+}
+
+LinearResample::LinearResample(int32_t samp_rate_in_hz,
+ int32_t samp_rate_out_hz, float filter_cutoff_hz,
+ int32_t num_zeros)
+ : samp_rate_in_(samp_rate_in_hz),
+ samp_rate_out_(samp_rate_out_hz),
+ filter_cutoff_(filter_cutoff_hz),
+ num_zeros_(num_zeros) {
+ assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 &&
+ filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz &&
+ filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0);
+
+ // base_freq is the frequency of the repeating unit, which is the gcd
+ // of the input frequencies.
+ int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_);
+ input_samples_in_unit_ = samp_rate_in_ / base_freq;
+ output_samples_in_unit_ = samp_rate_out_ / base_freq;
+
+ SetIndexesAndWeights();
+ Reset();
+}
+
+void LinearResample::SetIndexesAndWeights() {
+ first_index_.resize(output_samples_in_unit_);
+ weights_.resize(output_samples_in_unit_);
+
+ double window_width = num_zeros_ / (2.0 * filter_cutoff_);
+
+ for (int32_t i = 0; i < output_samples_in_unit_; i++) {
+ double output_t = i / static_cast<double>(samp_rate_out_);
+ double min_t = output_t - window_width, max_t = output_t + window_width;
+ // we do ceil on the min and floor on the max, because if we did it
+ // the other way around we would unnecessarily include indexes just
+ // outside the window, with zero coefficients. It's possible
+ // if the arguments to the ceil and floor expressions are integers
+ // (e.g. if filter_cutoff_ has an exact ratio with the sample rates),
+ // that we unnecessarily include something with a zero coefficient,
+ // but this is only a slight efficiency issue.
+ int32_t min_input_index = ceil(min_t * samp_rate_in_),
+ max_input_index = floor(max_t * samp_rate_in_),
+ num_indices = max_input_index - min_input_index + 1;
+ first_index_[i] = min_input_index;
+ weights_[i].resize(num_indices);
+ for (int32_t j = 0; j < num_indices; j++) {
+ int32_t input_index = min_input_index + j;
+ double input_t = input_index / static_cast<double>(samp_rate_in_),
+ delta_t = input_t - output_t;
+ // sign of delta_t doesn't matter.
+ weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_;
+ }
+ }
+}
+
+/** Here, t is a time in seconds representing an offset from
+ the center of the windowed filter function, and FilterFunction(t)
+ returns the windowed filter function, described
+ in the header as h(t) = f(t)g(t), evaluated at t.
+*/
+float LinearResample::FilterFunc(float t) const {
+ float window, // raised-cosine (Hanning) window of width
+ // num_zeros_/2*filter_cutoff_
+ filter; // sinc filter function
+ if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
+ window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
+ else
+ window = 0.0; // outside support of window function
+ if (t != 0)
+ filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t);
+ else
+ filter = 2 * filter_cutoff_; // limit of the function at t = 0
+ return filter * window;
+}
+
+void LinearResample::Reset() {
+ input_sample_offset_ = 0;
+ output_sample_offset_ = 0;
+ input_remainder_.resize(0);
+}
+
+void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
+ std::vector<float> *output) {
+ int64_t tot_input_samp = input_sample_offset_ + input_dim,
+ tot_output_samp = GetNumOutputSamples(tot_input_samp, flush);
+
+ assert(tot_output_samp >= output_sample_offset_);
+
+ output->resize(tot_output_samp - output_sample_offset_);
+
+ // samp_out is the index into the total output signal, not just the part
+ // of it we are producing here.
+ for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
+ samp_out++) {
+ int64_t first_samp_in;
+ int32_t samp_out_wrapped;
+ GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
+ const std::vector<float> &weights = weights_[samp_out_wrapped];
+ // first_input_index is the first index into "input" that we have a weight
+ // for.
+ int32_t first_input_index =
+ static_cast<int32_t>(first_samp_in - input_sample_offset_);
+ float this_output;
+ if (first_input_index >= 0 &&
+ first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
+ this_output =
+ DotProduct(input + first_input_index, weights.data(), weights.size());
+ } else { // Handle edge cases.
+ this_output = 0.0;
+ for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) {
+ float weight = weights[i];
+ int32_t input_index = first_input_index + i;
+ if (input_index < 0 &&
+ static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) {
+ this_output +=
+ weight * input_remainder_[input_remainder_.size() + input_index];
+ } else if (input_index >= 0 && input_index < input_dim) {
+ this_output += weight * input[input_index];
+ } else if (input_index >= input_dim) {
+ // We're past the end of the input and are adding zero; should only
+ // happen if the user specified flush == true, or else we would not
+ // be trying to output this sample.
+ assert(flush);
+ }
+ }
+ }
+ int32_t output_index =
+ static_cast<int32_t>(samp_out - output_sample_offset_);
+ (*output)[output_index] = this_output;
+ }
+
+ if (flush) {
+ Reset(); // Reset the internal state.
+ } else {
+ SetRemainder(input, input_dim);
+ input_sample_offset_ = tot_input_samp;
+ output_sample_offset_ = tot_output_samp;
+ }
+}
+
+int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
+ bool flush) const {
+ // For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
+ // where tick_freq is the least common multiple of samp_rate_in_ and
+ // samp_rate_out_.
+ int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_);
+ int32_t ticks_per_input_period = tick_freq / samp_rate_in_;
+
+ // work out the number of ticks in the time interval
+ // [ 0, input_num_samp/samp_rate_in_ ).
+ int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
+ if (!flush) {
+ float window_width = num_zeros_ / (2.0 * filter_cutoff_);
+ // To count the window-width in ticks we take the floor. This
+ // is because since we're looking for the largest integer num-out-samp
+ // that fits in the interval, which is open on the right, a reduction
+ // in interval length of less than a tick will never make a difference.
+ // For example, the largest integer in the interval [ 0, 2 ) and the
+ // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
+ // So when we're subtracting the window-width we can ignore the fractional
+ // part.
+ int32_t window_width_ticks = floor(window_width * tick_freq);
+ // The time-period of the output that we can sample gets reduced
+ // by the window-width (which is actually the distance from the
+ // center to the edge of the windowing function) if we're not
+ // "flushing the output".
+ interval_length_in_ticks -= window_width_ticks;
+ }
+ if (interval_length_in_ticks <= 0) return 0;
+
+ int32_t ticks_per_output_period = tick_freq / samp_rate_out_;
+ // Get the last output-sample in the closed interval, i.e. replacing [ ) with
+ // [ ]. Note: integer division rounds down. See
+ // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
+ // the notation.
+ int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
+ // We need the last output-sample in the open interval, so if it takes us to
+ // the end of the interval exactly, subtract one.
+ if (last_output_samp * ticks_per_output_period == interval_length_in_ticks)
+ last_output_samp--;
+
+ // First output-sample index is zero, so the number of output samples
+ // is the last output-sample plus one.
+ int64_t num_output_samp = last_output_samp + 1;
+ return num_output_samp;
+}
+
+// inline
+void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in,
+ int32_t *samp_out_wrapped) const {
+ // A unit is the smallest nonzero amount of time that is an exact
+ // multiple of the input and output sample periods. The unit index
+ // is the answer to "which numbered unit we are in".
+ int64_t unit_index = samp_out / output_samples_in_unit_;
+ // samp_out_wrapped is equal to samp_out % output_samples_in_unit_
+ *samp_out_wrapped =
+ static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_);
+ *first_samp_in =
+ first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_;
+}
+
+void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
+ std::vector<float> old_remainder(input_remainder_);
+ // max_remainder_needed is the width of the filter from side to side,
+ // measured in input samples. you might think it should be half that,
+ // but you have to consider that you might be wanting to output samples
+ // that are "in the past" relative to the beginning of the latest
+ // input... anyway, storing more remainder than needed is not harmful.
+ int32_t max_remainder_needed =
+ ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
+ input_remainder_.resize(max_remainder_needed);
+ for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
+ index < 0; index++) {
+ // we interpret "index" as an offset from the end of "input" and
+ // from the end of input_remainder_.
+ int32_t input_index = index + input_dim;
+ if (input_index >= 0) {
+ input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
+ input[input_index];
+ } else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) {
+ input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
+ old_remainder[input_index +
+ static_cast<int32_t>(old_remainder.size())];
+ // else leave it at zero.
+ }
+ }
+}
diff --git a/funasr/runtime/onnxruntime/src/resample.h b/funasr/runtime/onnxruntime/src/resample.h
new file mode 100644
index 0000000..b9a283a
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/resample.h
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2013 Pegah Ghahremani
+ * 2014 IMSL, PKU-HKUST (author: Wei Shi)
+ * 2014 Yanqing Sun, Junjie Wang
+ * 2014 Johns Hopkins University (author: Daniel Povey)
+ * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// this file is copied and modified from
+// kaldi/src/feat/resample.h
+
+#include <cstdint>
+#include <vector>
+
+
+/*
+ We require that the input and output sampling rate be specified as
+ integers, as this is an easy way to specify that their ratio be rational.
+*/
+
+class LinearResample {
+ public:
+ /// Constructor. We make the input and output sample rates integers, because
+ /// we are going to need to find a common divisor. This should just remind
+ /// you that they need to be integers. The filter cutoff needs to be less
+ /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros
+ /// controls the sharpness of the filter, more == sharper but less efficient.
+ /// We suggest around 4 to 10 for normal use.
+ LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz,
+ float filter_cutoff_hz, int32_t num_zeros);
+
+ /// Calling the function Reset() resets the state of the object prior to
+ /// processing a new signal; it is only necessary if you have called
+ /// Resample(x, x_size, false, y) for some signal, leading to a remainder of
+  /// the signal being stored, but you then abandon processing the signal before
+  /// calling Resample(x, x_size, true, y) for the last piece. Calling it
+  /// unnecessarily between signals will not do any harm.
+ void Reset();
+
+ /// This function does the resampling. If you call it with flush == true and
+ /// you have never called it with flush == false, it just resamples the input
+ /// signal (it resizes the output to a suitable number of samples).
+ ///
+ /// You can also use this function to process a signal a piece at a time.
+ /// suppose you break it into piece1, piece2, ... pieceN. You can call
+ /// \code{.cc}
+ /// Resample(piece1, piece1_size, false, &output1);
+ /// Resample(piece2, piece2_size, false, &output2);
+ /// Resample(piece3, piece3_size, true, &output3);
+ /// \endcode
+ /// If you call it with flush == false, it won't output the last few samples
+ /// but will remember them, so that if you later give it a second piece of
+ /// the input signal it can process it correctly.
+ /// If your most recent call to the object was with flush == false, it will
+ /// have internal state; you can remove this by calling Reset().
+ /// Empty input is acceptable.
+ void Resample(const float *input, int32_t input_dim, bool flush,
+ std::vector<float> *output);
+
+ //// Return the input and output sampling rates (for checks, for example)
+ int32_t GetInputSamplingRate() const { return samp_rate_in_; }
+ int32_t GetOutputSamplingRate() const { return samp_rate_out_; }
+
+ private:
+ void SetIndexesAndWeights();
+
+ float FilterFunc(float) const;
+
+ /// This function outputs the number of output samples we will output
+ /// for a signal with "input_num_samp" input samples. If flush == true,
+ /// we return the largest n such that
+ /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ),
+ /// and note that the interval is half-open. If flush == false,
+ /// define window_width as num_zeros / (2.0 * filter_cutoff_);
+ /// we return the largest n such that (n/samp_rate_out_) is in the interval
+ /// [ 0, input_num_samp/samp_rate_in_ - window_width ).
+ int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const;
+
+ /// Given an output-sample index, this function outputs to *first_samp_in the
+ /// first input-sample index that we have a weight on (may be negative),
+ /// and to *samp_out_wrapped the index into weights_ where we can get the
+ /// corresponding weights on the input.
+ inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in,
+ int32_t *samp_out_wrapped) const;
+
+ void SetRemainder(const float *input, int32_t input_dim);
+
+ private:
+ // The following variables are provided by the user.
+ int32_t samp_rate_in_;
+ int32_t samp_rate_out_;
+ float filter_cutoff_;
+ int32_t num_zeros_;
+
+ int32_t input_samples_in_unit_; ///< The number of input samples in the
+ ///< smallest repeating unit: num_samp_in_ =
+ ///< samp_rate_in_hz / Gcd(samp_rate_in_hz,
+ ///< samp_rate_out_hz)
+
+ int32_t output_samples_in_unit_; ///< The number of output samples in the
+ ///< smallest repeating unit: num_samp_out_
+ ///< = samp_rate_out_hz /
+ ///< Gcd(samp_rate_in_hz, samp_rate_out_hz)
+
+ /// The first input-sample index that we sum over, for this output-sample
+ /// index. May be negative; any truncation at the beginning is handled
+ /// separately. This is just for the first few output samples, but we can
+ /// extrapolate the correct input-sample index for arbitrary output samples.
+ std::vector<int32_t> first_index_;
+
+ /// Weights on the input samples, for this output-sample index.
+ std::vector<std::vector<float>> weights_;
+
+ // the following variables keep track of where we are in a particular signal,
+ // if it is being provided over multiple calls to Resample().
+
+ int64_t input_sample_offset_; ///< The number of input samples we have
+ ///< already received for this signal
+ ///< (including anything in remainder_)
+ int64_t output_sample_offset_; ///< The number of samples we have already
+ ///< output for this signal.
+ std::vector<float> input_remainder_; ///< A small trailing part of the
+ ///< previously seen input signal.
+};
diff --git a/funasr/runtime/python/grpc/Readme.md b/funasr/runtime/python/grpc/Readme.md
index 8c14985..822cb52 100644
--- a/funasr/runtime/python/grpc/Readme.md
+++ b/funasr/runtime/python/grpc/Readme.md
@@ -1,8 +1,6 @@
-# Using paraformer with grpc
+# Using funasr with grpc-python
We can send streaming audio data to server in real-time with grpc client every 10 ms e.g., and get transcribed text when stop speaking.
The audio data is in streaming, the asr inference process is in offline.
-
-
## For the Server
diff --git a/funasr/runtime/python/libtorch/README.md b/funasr/runtime/python/libtorch/README.md
index aeb2eae..fd64cc6 100644
--- a/funasr/runtime/python/libtorch/README.md
+++ b/funasr/runtime/python/libtorch/README.md
@@ -1,57 +1,54 @@
-## Using funasr with libtorch
+# Libtorch-python
-[FunASR](https://github.com/alibaba-damo-academy/FunASR) hopes to build a bridge between academic research and industrial applications on speech recognition. By supporting the training & finetuning of the industrial-grade speech recognition model released on ModelScope, researchers and developers can conduct research and production of speech recognition models more conveniently, and promote the development of speech recognition ecology. ASR for Fun锛�
+## Export the model
+### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
-### Introduction
-- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
+```shell
+pip3 install torch torchaudio
+pip install -U modelscope
+pip install -U funasr
+```
-### Steps:
-1. Export the model.
- - Command: (`Tips`: torch >= 1.11.0 is required.)
+### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
- More details ref to ([export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
+```shell
+python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize True
+```
- - `e.g.`, Export model from modelscope
- ```shell
- python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize False
- ```
- - `e.g.`, Export model from local path, the model'name must be `model.pb`.
- ```shell
- python -m funasr.export.export_model --model-name ./damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize False
- ```
-
-
-2. Install the `funasr_torch`.
+## Install the `funasr_torch`.
- install from pip
- ```shell
- pip install --upgrade funasr_torch -i https://pypi.Python.org/simple
- ```
- or install from source code
+install from pip
+```shell
+pip install -U funasr_torch
+# For the users in China, you could install with the command:
+# pip install -U funasr_torch -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+or install from source code
- ```shell
- git clone https://github.com/alibaba/FunASR.git && cd FunASR
- cd funasr/runtime/python/libtorch
- python setup.py build
- python setup.py install
- ```
+```shell
+git clone https://github.com/alibaba/FunASR.git && cd FunASR
+cd funasr/runtime/python/libtorch
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
-3. Run the demo.
- - Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.
- - Input: wav formt file, support formats: `str, np.ndarray, List[str]`
- - Output: `List[str]`: recognition result.
- - Example:
- ```python
- from funasr_torch import Paraformer
+## Run the demo.
+- Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.
+- Input: wav format file, supported formats: `str, np.ndarray, List[str]`
+- Output: `List[str]`: recognition result.
+- Example:
+ ```python
+ from funasr_torch import Paraformer
- model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
- model = Paraformer(model_dir, batch_size=1)
+ model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+ model = Paraformer(model_dir, batch_size=1)
- wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+ wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
- result = model(wav_path)
- print(result)
- ```
+ result = model(wav_path)
+ print(result)
+ ```
## Performance benchmark
diff --git a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index e169087..9954daa 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -23,7 +23,6 @@
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
- pred_bias: int = 1,
quantize: bool = False,
intra_op_num_threads: int = 1,
):
@@ -48,7 +47,10 @@
self.batch_size = batch_size
self.device_id = device_id
self.plot_timestamp_to = plot_timestamp_to
- self.pred_bias = pred_bias
+ if "predictor_bias" in config['model_conf'].keys():
+ self.pred_bias = config['model_conf']['predictor_bias']
+ else:
+ self.pred_bias = 0
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
diff --git a/funasr/runtime/python/libtorch/setup.py b/funasr/runtime/python/libtorch/setup.py
index 1560950..fd8b151 100644
--- a/funasr/runtime/python/libtorch/setup.py
+++ b/funasr/runtime/python/libtorch/setup.py
@@ -15,10 +15,10 @@
setuptools.setup(
name='funasr_torch',
- version='0.0.3',
+ version='0.0.4',
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license="The MIT License",
@@ -31,7 +31,7 @@
"PyYAML>=5.1.2", "torch-quant >= 0.4.0"],
packages=find_packages(include=["torch_paraformer*"]),
keywords=[
- 'funasr,paraformer, funasr_torch'
+ 'funasr, paraformer, funasr_torch'
],
classifiers=[
'Programming Language :: Python :: 3.6',
diff --git a/funasr/runtime/python/onnxruntime/README.md b/funasr/runtime/python/onnxruntime/README.md
index e19e3a2..e85e08a 100644
--- a/funasr/runtime/python/onnxruntime/README.md
+++ b/funasr/runtime/python/onnxruntime/README.md
@@ -1,31 +1,28 @@
-## Using funasr with ONNXRuntime
+# ONNXRuntime-python
+
+## Export the model
+### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
+
+```shell
+pip3 install torch torchaudio
+pip install -U modelscope
+pip install -U funasr
+```
+
+### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
+
+```shell
+python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
+```
-### Introduction
-- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
-
-
-### Steps:
-1. Export the model.
- - Command: (`Tips`: torch >= 1.11.0 is required.)
-
- More details ref to ([export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
-
- - `e.g.`, Export model from modelscope
- ```shell
- python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize False
- ```
- - `e.g.`, Export model from local path, the model'name must be `model.pb`.
- ```shell
- python -m funasr.export.export_model --model-name ./damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize False
- ```
-
-
-2. Install the `funasr_onnx`
+## Install the `funasr_onnx`
install from pip
```shell
-pip install --upgrade funasr_onnx -i https://pypi.Python.org/simple
+pip install -U funasr_onnx
+# For the users in China, you could install with the command:
+# pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
or install from source code
@@ -33,26 +30,27 @@
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/onnxruntime
-python setup.py build
-python setup.py install
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
-3. Run the demo.
- - Model_dir: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`.
- - Input: wav formt file, support formats: `str, np.ndarray, List[str]`
- - Output: `List[str]`: recognition result.
- - Example:
- ```python
- from funasr_onnx import Paraformer
+## Run the demo
+- Model_dir: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`.
+- Input: wav format file, supported formats: `str, np.ndarray, List[str]`
+- Output: `List[str]`: recognition result.
+- Example:
+ ```python
+ from funasr_onnx import Paraformer
- model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
- model = Paraformer(model_dir, batch_size=1)
+ model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+ model = Paraformer(model_dir, batch_size=1)
- wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+ wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
- result = model(wav_path)
- print(result)
- ```
+ result = model(wav_path)
+ print(result)
+ ```
## Performance benchmark
diff --git a/funasr/runtime/python/onnxruntime/demo_vad.py b/funasr/runtime/python/onnxruntime/demo_vad.py
deleted file mode 100644
index 2e17197..0000000
--- a/funasr/runtime/python/onnxruntime/demo_vad.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import soundfile
-from funasr_onnx import Fsmn_vad
-
-
-model_dir = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-wav_path = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
-model = Fsmn_vad(model_dir)
-
-#offline vad
-# result = model(wav_path)
-# print(result)
-
-#online vad
-speech, sample_rate = soundfile.read(wav_path)
-speech_length = speech.shape[0]
-
-sample_offset = 0
-step = 160 * 10
-param_dict = {'in_cache': []}
-for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
- if sample_offset + step >= speech_length - 1:
- step = speech_length - sample_offset
- is_final = True
- else:
- is_final = False
- param_dict['is_final'] = is_final
- segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
- param_dict=param_dict)
- print(segments_result)
-
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_offline.py b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
new file mode 100644
index 0000000..ea76ec3
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
@@ -0,0 +1,11 @@
+import soundfile
+from funasr_onnx import Fsmn_vad
+
+
+model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
+model = Fsmn_vad(model_dir)
+
+#offline vad
+result = model(wav_path)
+print(result)
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_online.py b/funasr/runtime/python/onnxruntime/demo_vad_online.py
new file mode 100644
index 0000000..1ab4d9d
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_vad_online.py
@@ -0,0 +1,28 @@
+import soundfile
+from funasr_onnx import Fsmn_vad_online
+
+
+model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
+model = Fsmn_vad_online(model_dir)
+
+
+##online vad
+speech, sample_rate = soundfile.read(wav_path)
+speech_length = speech.shape[0]
+#
+sample_offset = 0
+step = 1600
+param_dict = {'in_cache': []}
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+ if sample_offset + step >= speech_length - 1:
+ step = speech_length - sample_offset
+ is_final = True
+ else:
+ is_final = False
+ param_dict['is_final'] = is_final
+ segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
+ param_dict=param_dict)
+ if segments_result:
+ print(segments_result)
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 86f0e8e..7d8d662 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,6 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .vad_bin import Fsmn_vad
+from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index cbdb8d9..e3ef8c7 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -23,7 +23,6 @@
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
- pred_bias: int = 1,
quantize: bool = False,
intra_op_num_threads: int = 4,
):
@@ -47,7 +46,10 @@
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
- self.pred_bias = pred_bias
+ if "predictor_bias" in config['model_conf'].keys():
+ self.pred_bias = config['model_conf']['predictor_bias']
+ else:
+ self.pred_bias = 0
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 0dc728a..bbbb913 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -13,6 +13,11 @@
class CT_Transformer():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
@@ -119,6 +124,11 @@
class CT_Transformer_VadRealtime(CT_Transformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ https://arxiv.org/pdf/2003.01309.pdf
+ """
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
@@ -159,7 +169,7 @@
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
- "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
+ "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
}
try:
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
index 3f6c3d1..029f529 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -439,10 +439,9 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
def __call__(self, score: np.ndarray, waveform: np.ndarray,
- is_final: bool = False, max_end_sil: int = 800
+ is_final: bool = False, max_end_sil: int = 800, online: bool = False
):
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
@@ -457,20 +456,29 @@
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point:
- continue
- if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
- continue
- start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
- if self.output_data_buf[i].contain_seg_end_point:
- end_ms = self.output_data_buf[i].end_ms
- self.next_seg = True
- self.output_data_buf_offset += 1
+ if online:
+ if not self.output_data_buf[i].contain_seg_start_point:
+ continue
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+ if self.output_data_buf[i].contain_seg_end_point:
+ end_ms = self.output_data_buf[i].end_ms
+ self.next_seg = True
+ self.output_data_buf_offset += 1
+ else:
+ end_ms = -1
+ self.next_seg = False
else:
- end_ms = -1
- self.next_seg = False
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
+ continue
+ start_ms = self.output_data_buf[i].start_ms
+ end_ms = self.output_data_buf[i].end_ms
+ self.output_data_buf_offset += 1
segment = [start_ms, end_ms]
segment_batch.append(segment)
+
if segment_batch:
segments.append(segment_batch)
if is_final:
@@ -605,3 +613,4 @@
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
index 11a8644..c92db4e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
+import copy
import numpy as np
from typeguard import check_argument_types
@@ -153,6 +154,187 @@
cmvn = np.array([means, vars])
return cmvn
+
+class WavFrontendOnline(WavFrontend):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # self.fbank_fn = knf.OnlineFbank(self.opts)
+ # add variables
+ self.frame_sample_length = int(self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000)
+ self.frame_shift_sample_length = int(self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000)
+ self.waveform = None
+ self.reserve_waveforms = None
+ self.input_cache = None
+ self.lfr_splice_cache = []
+
+ @staticmethod
+ # inputs has catted the cache
+ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
+ np.ndarray, np.ndarray, int]:
+ """
+ Apply lfr with data
+ """
+
+ LFR_inputs = []
+ T = inputs.shape[0] # include the right context
+ T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
+ splice_idx = T_lfr
+ for i in range(T_lfr):
+ if lfr_m <= T - i * lfr_n:
+ LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
+ else: # process last LFR frame
+ if is_final:
+ num_padding = lfr_m - (T - i * lfr_n)
+ frame = (inputs[i * lfr_n:]).reshape(-1)
+ for _ in range(num_padding):
+ frame = np.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ else:
+ # update splice_idx and break the circle
+ splice_idx = i
+ break
+ splice_idx = min(T - 1, splice_idx * lfr_n)
+ lfr_splice_cache = inputs[splice_idx:, :]
+ LFR_outputs = np.vstack(LFR_inputs)
+ return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
+
+ @staticmethod
+ def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
+ frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
+ return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+
+
+ def fbank(
+ self,
+ input: np.ndarray,
+ input_lengths: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ self.fbank_fn = knf.OnlineFbank(self.opts)
+ batch_size = input.shape[0]
+ if self.input_cache is None:
+ self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
+ input = np.concatenate((self.input_cache, input), axis=1)
+ frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
+ # update self.in_cache
+ self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
+ waveforms = np.empty(0, dtype=np.int16)
+ feats_pad = np.empty(0, dtype=np.float32)
+ feats_lens = np.empty(0, dtype=np.int32)
+ if frame_num:
+ waveforms = []
+ feats = []
+ feats_lens = []
+ for i in range(batch_size):
+ waveform = input[i]
+ waveforms.append(
+ waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
+ waveform = waveform * (1 << 15)
+
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+ frames = self.fbank_fn.num_frames_ready
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
+ for i in range(frames):
+ mat[i, :] = self.fbank_fn.get_frame(i)
+ feat = mat.astype(np.float32)
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
+ feats.append(mat)
+ feats_lens.append(feat_len)
+
+ waveforms = np.stack(waveforms)
+ feats_lens = np.array(feats_lens)
+ feats_pad = np.array(feats)
+ self.fbanks = feats_pad
+ self.fbanks_lens = copy.deepcopy(feats_lens)
+ return waveforms, feats_pad, feats_lens
+
+ def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
+ return self.fbanks, self.fbanks_lens
+
+ def lfr_cmvn(
+ self,
+ input: np.ndarray,
+ input_lengths: np.ndarray,
+ is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
+ batch_size = input.shape[0]
+ feats = []
+ feats_lens = []
+ lfr_splice_frame_idxs = []
+ for i in range(batch_size):
+ mat = input[i, :input_lengths[i], :]
+ lfr_splice_frame_idx = -1
+ if self.lfr_m != 1 or self.lfr_n != 1:
+ # update self.lfr_splice_cache in self.apply_lfr
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
+ is_final)
+ if self.cmvn_file is not None:
+ mat = self.apply_cmvn(mat)
+ feat_length = mat.shape[0]
+ feats.append(mat)
+ feats_lens.append(feat_length)
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+
+ feats_lens = np.array(feats_lens)
+ feats_pad = np.array(feats)
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+
+ def extract_fbank(
+ self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ batch_size = input.shape[0]
+ assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
+ waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
+ if feats.shape[0]:
+ self.waveforms = waveforms if self.reserve_waveforms is None else np.concatenate(
+ (self.reserve_waveforms, waveforms), axis=1)
+ if not self.lfr_splice_cache:
+ for i in range(batch_size):
+ self.lfr_splice_cache.append(np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0))
+
+ if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
+ lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
+ feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
+ feats_lengths += lfr_splice_cache_np[0].shape[0]
+ frame_from_waveforms = int(
+ (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+ minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+ feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(feats, feats_lengths, is_final)
+ if self.lfr_m == 1:
+ self.reserve_waveforms = None
+ else:
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+ # print('frame_frame: ' + str(frame_from_waveforms))
+ self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
+ sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
+ self.waveforms = self.waveforms[:, :sample_length]
+ else:
+ # update self.reserve_waveforms and self.lfr_splice_cache
+ self.reserve_waveforms = self.waveforms[:,
+ :-(self.frame_sample_length - self.frame_shift_sample_length)]
+ for i in range(batch_size):
+ self.lfr_splice_cache[i] = np.concatenate((self.lfr_splice_cache[i], feats[i]), axis=0)
+ return np.empty(0, dtype=np.float32), feats_lengths
+ else:
+ if is_final:
+ self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
+ feats = np.stack(self.lfr_splice_cache)
+ feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
+ feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
+ if is_final:
+ self.cache_reset()
+ return feats, feats_lengths
+
+ def get_waveforms(self):
+ return self.waveforms
+
+ def cache_reset(self):
+ self.fbank_fn = knf.OnlineFbank(self.opts)
+ self.reserve_waveforms = None
+ self.input_cache = None
+ self.lfr_splice_cache = []
+
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@@ -188,4 +370,4 @@
return feat, feat_len
if __name__ == '__main__':
- test()
\ No newline at end of file
+ test()
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index 221867d..ab8f041 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -11,13 +11,18 @@
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
-from .utils.frontend import WavFrontend
+from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel
logging = get_logger()
class Fsmn_vad():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
@@ -59,37 +64,48 @@
def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
- # waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
-
- param_dict = kwargs.get('param_dict', dict())
- is_final = param_dict.get('is_final', False)
- audio_in_cache = param_dict.get('audio_in_cache', None)
- audio_in_cum = audio_in
- if audio_in_cache is not None:
- audio_in_cum = np.concatenate((audio_in_cache, audio_in_cum))
- param_dict['audio_in_cache'] = audio_in_cum
- feats, feats_len = self.extract_feat([audio_in_cum])
-
- in_cache = param_dict.get('in_cache', list())
- in_cache = self.prepare_cache(in_cache)
- beg_idx = param_dict.get('beg_idx',0)
- feats = feats[:, beg_idx:beg_idx+8, :]
- param_dict['beg_idx'] = beg_idx + feats.shape[1]
- try:
- inputs = [feats]
- inputs.extend(in_cache)
- scores, out_caches = self.infer(inputs)
- param_dict['in_cache'] = out_caches
- segments = self.vad_scorer(scores, audio_in[None, :], is_final=is_final, max_end_sil=self.max_end_sil)
- # print(segments)
- if len(segments) == 1 and segments[0][0][1] != -1:
- self.frontend.reset_status()
+ waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ is_final = kwargs.get('kwargs', False)
+
+ segments = [[]] * self.batch_size
+ for beg_idx in range(0, waveform_nums, self.batch_size):
-
- except ONNXRuntimeError:
- logging.warning(traceback.format_exc())
- logging.warning("input wav is silence or noise")
- segments = []
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ waveform = waveform_list[beg_idx:end_idx]
+ feats, feats_len = self.extract_feat(waveform)
+ waveform = np.array(waveform)
+ param_dict = kwargs.get('param_dict', dict())
+ in_cache = param_dict.get('in_cache', list())
+ in_cache = self.prepare_cache(in_cache)
+ try:
+ t_offset = 0
+ step = int(min(feats_len.max(), 6000))
+ for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
+ if t_offset + step >= feats_len - 1:
+ step = feats_len - t_offset
+ is_final = True
+ else:
+ is_final = False
+ feats_package = feats[:, t_offset:int(t_offset + step), :]
+ waveform_package = waveform[:, t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400)]
+
+ inputs = [feats_package]
+ # inputs = [feats]
+ inputs.extend(in_cache)
+ scores, out_caches = self.infer(inputs)
+ in_cache = out_caches
+ segments_part = self.vad_scorer(scores, waveform_package, is_final=is_final, max_end_sil=self.max_end_sil, online=False)
+ # segments = self.vad_scorer(scores, waveform[0][None, :], is_final=is_final, max_end_sil=self.max_end_sil)
+
+ if segments_part:
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
+
+ except ONNXRuntimeError:
+ # logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ segments = ''
return segments
@@ -140,4 +156,125 @@
outputs = self.ort_infer(feats)
scores, out_caches = outputs[0], outputs[1:]
return scores, out_caches
-
\ No newline at end of file
+
+
+class Fsmn_vad_online():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+ def __init__(self, model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ max_end_sil: int = None,
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
+
+ model_file = os.path.join(model_dir, 'model.onnx')
+ if quantize:
+ model_file = os.path.join(model_dir, 'model_quant.onnx')
+ config_file = os.path.join(model_dir, 'vad.yaml')
+ cmvn_file = os.path.join(model_dir, 'vad.mvn')
+ config = read_yaml(config_file)
+
+ self.frontend = WavFrontendOnline(
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
+ )
+ self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.batch_size = batch_size
+ self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+ self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+ self.encoder_conf = config["encoder_conf"]
+
+ def prepare_cache(self, in_cache: list = []):
+ if len(in_cache) > 0:
+ return in_cache
+ fsmn_layers = self.encoder_conf["fsmn_layers"]
+ proj_dim = self.encoder_conf["proj_dim"]
+ lorder = self.encoder_conf["lorder"]
+ for i in range(fsmn_layers):
+ cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
+ in_cache.append(cache)
+ return in_cache
+
+ def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
+ waveforms = np.expand_dims(audio_in, axis=0)
+
+ param_dict = kwargs.get('param_dict', dict())
+ is_final = param_dict.get('is_final', False)
+ feats, feats_len = self.extract_feat(waveforms, is_final)
+ segments = []
+ if feats.size != 0:
+ in_cache = param_dict.get('in_cache', list())
+ in_cache = self.prepare_cache(in_cache)
+ try:
+ inputs = [feats]
+ inputs.extend(in_cache)
+ scores, out_caches = self.infer(inputs)
+ param_dict['in_cache'] = out_caches
+ waveforms = self.frontend.get_waveforms()
+ segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil,
+ online=True)
+
+
+ except ONNXRuntimeError:
+ # logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ segments = []
+ return segments
+
+ def load_data(self,
+ wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(
+ f'The type of {wav_content} is not in [str, np.ndarray, list]')
+
+ def extract_feat(self,
+ waveforms: np.ndarray, is_final: bool = False
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
+ for idx, waveform in enumerate(waveforms):
+ waveforms_lens[idx] = waveform.shape[-1]
+
+ feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
+ # feats.append(feat)
+ # feats_len.append(feat_len)
+
+ # feats = self.pad_feats(feats, np.max(feats_len))
+ # feats_len = np.array(feats_len).astype(np.int32)
+ return feats.astype(np.float32), feats_len.astype(np.int32)
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
+ return np.pad(feat, pad_width, 'constant', constant_values=0)
+
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
+ def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+
+ outputs = self.ort_infer(feats)
+ scores, out_caches = outputs[0], outputs[1:]
+ return scores, out_caches
+
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 1a8ed7b..975b14d 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,14 +13,14 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.3'
+VERSION_NUM = '0.0.5'
setuptools.setup(
name=MODULE_NAME,
version=VERSION_NUM,
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license='MIT',
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 8d63b27..3d2004c 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -467,7 +467,7 @@
parser.add_argument(
"--batch_interval",
type=int,
- default=10000,
+ default=-1,
help="The batch interval for saving model.",
)
group.add_argument(
@@ -1587,6 +1587,8 @@
dest_sample_rate = args.frontend_conf["fs"]
else:
dest_sample_rate = 16000
+ else:
+ dest_sample_rate = 16000
dataset = ESPnetDataset(
iter_options.data_path_and_name_and_type,
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e151473..d52c9c3 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,13 +38,16 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@@ -121,6 +124,7 @@
asr=ESPnetASRModel,
uniasr=UniASR,
paraformer=Paraformer,
+ paraformer_online=ParaformerOnline,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
@@ -150,6 +154,7 @@
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
+ chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
@@ -207,6 +212,16 @@
type_check=AbsDecoder,
default="rnn",
)
+
+rnnt_decoder_choices = ClassChoices(
+ "rnnt_decoder",
+ classes=dict(
+ rnnt=RNNTDecoder,
+ ),
+ type_check=RNNTDecoder,
+ default="rnnt",
+)
+
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
@@ -1331,3 +1346,378 @@
) -> Tuple[str, ...]:
retval = ("speech", "text")
return retval
+
+
+class ASRTransducerTask(AbsTask):
+ """ASR Transducer Task definition."""
+
+ num_optimizers: int = 1
+
+ class_choices_list = [
+ frontend_choices,
+ specaug_choices,
+ normalize_choices,
+ encoder_choices,
+ rnnt_decoder_choices,
+ ]
+
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ """Add Transducer task arguments.
+ Args:
+ cls: ASRTransducerTask object.
+ parser: Transducer arguments parser.
+ """
+ group = parser.add_argument_group(description="Task related.")
+
+ # required = parser.get_default("required")
+ # required += ["token_list"]
+
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="Integer-string mapper for tokens.",
+ )
+ group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of dimensions for input features.",
+ )
+ group.add_argument(
+ "--init",
+ type=str_or_none,
+ default=None,
+ help="Type of model initialization to use.",
+ )
+ group.add_argument(
+ "--model_conf",
+ action=NestedDictAction,
+ default=get_default_kwargs(TransducerModel),
+ help="The keyword arguments for the model class.",
+ )
+ # group.add_argument(
+ # "--encoder_conf",
+ # action=NestedDictAction,
+ # default={},
+ # help="The keyword arguments for the encoder class.",
+ # )
+ group.add_argument(
+ "--joint_network_conf",
+ action=NestedDictAction,
+ default={},
+ help="The keyword arguments for the joint network class.",
+ )
+ group = parser.add_argument_group(description="Preprocess related.")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Whether to apply preprocessing to input data.",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word", "phn"],
+ help="The type of tokens to use during tokenization.",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The path of the sentencepiece model.",
+ )
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="The 'non_linguistic_symbols' file path.",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Text cleaner to use.",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="g2p method to use if --token_type=phn.",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Normalization value for maximum amplitude scaling.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The RIR SCP file path.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability of the applied RIR convolution.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The path of noise SCP file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability of the applied noise addition.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of the noise decibel level.",
+ )
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --decoder and --decoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ """Build collate function.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ train: Training mode.
+ Return:
+ : Callable collate function.
+ """
+ assert check_argument_types()
+
+ return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ """Build pre-processing function.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ train: Training mode.
+ Return:
+ : Callable pre-processing function.
+ """
+ assert check_argument_types()
+
+ if args.use_preprocessor:
+ retval = CommonPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob
+ if hasattr(args, "rir_apply_prob")
+ else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob
+ if hasattr(args, "noise_apply_prob")
+ else 1.0,
+ noise_db_range=args.noise_db_range
+ if hasattr(args, "noise_db_range")
+ else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize
+ if hasattr(args, "rir_scp")
+ else None,
+ )
+ else:
+ retval = None
+
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ """Required data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ train: Training mode.
+ inference: Inference mode.
+ Return:
+ retval: Required task data.
+ """
+ if not inference:
+ retval = ("speech", "text")
+ else:
+ retval = ("speech",)
+
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ """Optional data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ train: Training mode.
+ inference: Inference mode.
+ Return:
+ retval: Optional task data.
+ """
+ retval = ()
+ assert check_return_type(retval)
+
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace) -> TransducerModel:
+ """Required data depending on task mode.
+ Args:
+ cls: ASRTransducerTask object.
+ args: Task arguments.
+ Return:
+ model: ASR Transducer model.
+ """
+ assert check_argument_types()
+
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size }")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Encoder
+
+ if getattr(args, "encoder", None) is not None:
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size, **args.encoder_conf)
+ else:
+ encoder = Encoder(input_size, **args.encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ # 5. Decoder
+ rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+ decoder = rnnt_decoder_class(
+ vocab_size,
+ **args.rnnt_decoder_conf,
+ )
+ decoder_output_size = decoder.output_size
+
+ if getattr(args, "decoder", None) is not None:
+ att_decoder_class = decoder_choices.get_class(args.att_decoder)
+
+ att_decoder = att_decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+ else:
+ att_decoder = None
+ # 6. Joint Network
+ joint_network = JointNetwork(
+ vocab_size,
+ encoder_output_size,
+ decoder_output_size,
+ **args.joint_network_conf,
+ )
+
+ # 7. Build model
+
+ if encoder.unified_model_training:
+ model = UnifiedTransducerModel(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
+
+ else:
+ model = TransducerModel(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
+
+ # 8. Initialize model
+ if args.init is not None:
+ raise NotImplementedError(
+ "Currently not supported.",
+ "Initialization part will be reworked in a short future.",
+ )
+
+ #assert check_return_type(model)
+
+ return model
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 096a5c8..45e4ce7 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -1,3 +1,11 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+https://arxiv.org/abs/2211.10243
+TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+https://arxiv.org/abs/2303.05397
+"""
+
import argparse
import logging
import os
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index bef5dc5..9710447 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -1,3 +1,7 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+"""
+
import argparse
import logging
import os
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index b12bded..9574a0d 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -571,8 +571,7 @@
#ouput dir
output_dir = Path(options.output_dir)
#batch interval
- batch_interval = options.batch_interval
- assert batch_interval > 0
+ batch_interval = options.batch_interval
start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
@@ -580,14 +579,17 @@
):
assert isinstance(batch, dict), type(batch)
- if rank == 0:
+ if batch_interval > 0 and (not distributed_option.distributed or rank == 0):
if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
- if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
- buffer = BytesIO()
- torch.save(model.state_dict(), buffer)
- options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
-
+ if num_batch_updates % batch_interval == 0:
+ if options.use_pai and options.oss_bucket is not None:
+ buffer = BytesIO()
+ torch.save(model.state_dict(), buffer)
+ options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
+ else:
+ torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
+
if distributed:
torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
if iterator_stop > 0:
diff --git a/funasr/version.txt b/funasr/version.txt
index 1d0ba9e..267577d 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.4.0
+0.4.1
diff --git a/setup.py b/setup.py
index c3eed88..6deaf9c 100644
--- a/setup.py
+++ b/setup.py
@@ -123,7 +123,7 @@
name="funasr",
version=version,
url="https://github.com/alibaba-damo-academy/FunASR.git",
- author="Speech Lab, Alibaba Group, China",
+ author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index b3c5a24..2f2f11d 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -43,6 +43,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
def test_paraformer(self):
inference_pipeline = pipeline(
@@ -51,6 +52,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
class TestMfccaInferencePipelines(unittest.TestCase):
diff --git a/tests/test_punctuation_pipeline.py b/tests/test_punctuation_pipeline.py
index 52be9bb..e582042 100644
--- a/tests/test_punctuation_pipeline.py
+++ b/tests/test_punctuation_pipeline.py
@@ -26,16 +26,14 @@
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
- model_revision="v1.0.0",
)
inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
vads = inputs.split("|")
- cache_out = []
rec_result_all = "outputs:"
+ param_dict = {"cache": []}
for vad in vads:
- rec_result = inference_pipeline(text_in=vad, cache=cache_out)
- cache_out = rec_result['cache']
- rec_result_all += rec_result['text']
+ rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
+ rec_result_all += rec_result["text"]
logger.info("punctuation inference result: {0}".format(rec_result_all))
--
Gitblit v1.9.1