游雁
2023-02-16 544b798b32819fe2ffed1fccb44e8c2620449a53
Merge branch 'dev_gzf' of github.com:alibaba-damo-academy/FunASR into dev_gzf
add
69 files modified
30 files added
4674 lines changed
.github/workflows/main.yml 17
.gitignore 1
README.md 38
docs/build_task.md 106
docs/get_started.md 43
docs/images/dingding.jpg
docs/images/wechat.png
docs/index.rst 2
docs/installation.md 30
docs/modelscope_usages.md 53
docs_cn/build_task.md 105
docs_cn/get_started.md 3
docs_cn/index.rst 1
docs_cn/modelscope_usages.md 18
egs/alimeeting/diarization/sond/README.md 21
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md 53
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md 40
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py 35
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py 103
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py 67
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md 19
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py 21
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py 36
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/README.md 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/infer_after_finetune.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py 2
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py 2
funasr/bin/asr_inference_launch.py 6
funasr/bin/asr_inference_mfcca.py 764
funasr/bin/asr_inference_paraformer.py 45
funasr/bin/asr_inference_paraformer_timestamp.py 10
funasr/bin/asr_inference_paraformer_vad.py 1
funasr/bin/asr_inference_paraformer_vad_punc.py 12
funasr/bin/asr_inference_uniasr.py 14
funasr/bin/asr_inference_uniasr_vad.py 12
funasr/bin/build_trainer.py 2
funasr/bin/vad_inference.py 48
funasr/bin/vad_inference_launch.py 6
funasr/datasets/iterable_dataset.py 156
funasr/datasets/large_datasets/dataset.py 7
funasr/datasets/preprocessor.py 2
funasr/export/README.md 48
funasr/export/export_model.py 20
funasr/export/models/e2e_asr_paraformer.py 2
funasr/export/models/predictor/cif.py 50
funasr/models/e2e_asr_mfcca.py 322
funasr/models/e2e_vad.py 117
funasr/models/encoder/encoder_layer_mfcca.py 270
funasr/models/encoder/fsmn_encoder.py 168
funasr/models/encoder/mfcca_encoder.py 450
funasr/models/frontend/default.py 125
funasr/models/frontend/wav_frontend.py 5
funasr/runtime/__init__.py
funasr/runtime/python/__init__.py
funasr/runtime/python/onnxruntime/__init__.py
funasr/runtime/python/onnxruntime/paraformer/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/README.md 75
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py 3
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/demo.py 9
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py 144
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/requirements.txt 6
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/frontend.py 136
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/postprocess_utils.py 240
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/utils.py 256
funasr/runtime/python/torchscripts/__init__.py
funasr/runtime/python/torchscripts/paraformer/__init__.py
funasr/tasks/abs_task.py 10
funasr/tasks/asr.py 138
funasr/tasks/vad.py 7
funasr/utils/timestamp_tools.py 82
funasr/utils/wav_utils.py 32
.github/workflows/main.yml
@@ -5,7 +5,7 @@
      - main
  push:
    branches:
      - dev
      - dev_wjm
jobs:
  docs:
@@ -14,18 +14,27 @@
      - uses: actions/checkout@v1
      - uses: ammaraskar/sphinx-action@master
        with:
          docs-folder: "docs/"
          pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme"
      - uses: ammaraskar/sphinx-action@master
        with:
          docs-folder: "docs_cn/"
          pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme"
      - name: deploy copy
        if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev'
        if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev_wjm'
        run: |
          mkdir public
          touch public/.nojekyll
          cp -r docs_cn/_build/html/* public/
          mkdir public/en
          touch public/en/.nojekyll
          cp -r docs/_build/html/* public/en/
          mkdir public/cn
          touch public/cn/.nojekyll
          cp -r docs_cn/_build/html/* public/cn/
      - name: deploy github.io pages
        if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev'
        if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev_wjm'
        uses: peaceiris/actions-gh-pages@v2.3.1
        env:
          GITHUB_TOKEN: ${{ secrets.ACCESS_TOKEN }}
.gitignore
@@ -6,3 +6,4 @@
.DS_Store
init_model/
*.tar.gz
test_local/
README.md
@@ -1,4 +1,4 @@
<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>
[//]: # (<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>)
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
@@ -7,7 +7,8 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new) 
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/index.html)
| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
@@ -15,14 +16,31 @@
| [**Contact**](#contact)
## What's new: 
### 2023.1.16, funasr-0.1.6
### 2023.2.17, funasr-0.2.0, modelscope-1.3.0
- We support a new feature, export paraformer models into [onnx and torchscripts](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) from modelscope. The local finetuned models are also supported.
- We support a new feature, [onnxruntime](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer), you could deploy the runtime without modelscope or funasr, for the [paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model, the rtf of onnxruntime is 3x speedup(0.110->0.038) on cpu, [details](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer#speed).
- We support a new feature, [grpc](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc), you could build the ASR service with grpc, by deploying the modelscope pipeline or onnxruntime.
- We release a new model [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary), which supports the hotword customization based on the incentive enhancement, and improves the recall and precision of hotwords.
- We optimize the timestamp alignment of [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary), the prediction accuracy of timestamp is much improved, and achieving accumulated average shift (aas) of 74.7ms, [details](https://arxiv.org/abs/2301.12343).
- We release a new model, [8k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which can predict the duration of non-silence speech. It can be freely integrated with any ASR models in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
- We release a new model, [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary), a multi-channel multi-speaker model which is independent of the number and geometry of microphones and supports Mandarin meeting transcription.
- We release several new UniASR models:
[Southern Fujian Dialect model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary),
[French model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary),
[German model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary),
[Vietnamese model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary),
[Persian model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary).
- We release a new model, [paraformer-data2vec model](https://www.modelscope.cn/models/damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/summary), an unsupervised model pretrained on AISHELL-2, which is used to initialize a paraformer model that is then finetuned on AISHELL-1.
- Various new audio input types are now supported by the modelscope inference pipeline, including mp3, flac, ogg, opus, etc.
### 2023.1.16, funasr-0.1.6, modelscope-1.2.0
- We release a new version model [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary), which integrate the [VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) model, [ASR](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary),
 [Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) model and timestamp together. The model could take in several hours long inputs.
- We release a new type model, [VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which could predict the duration of none-silence speech. It could be freely integrated with any ASR models in [Model Zoo](docs/modelscope_models.md).
- We release a new type model, [Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary), which could predict the punctuation of ASR models's results. It could be freely integrated with any ASR models in [Model Zoo](docs/modelscope_models.md).
- We release a new model, [16k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which can predict the duration of non-silence speech. It can be freely integrated with any ASR models in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
- We release a new model, [Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary), which can predict punctuation for the results of ASR models. It can be freely integrated with any ASR models in [Model Zoo](docs/modelscope_models.md).
- We release a new model, [Data2vec](https://www.modelscope.cn/models/damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch/summary), an unsupervised pretraining model which could be finetuned on ASR and other downstream tasks.
- We release a new model, [Paraformer-Tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary), a lightweight Paraformer model which supports Mandarin command words recognition.
- We release a new type model, [SV](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary), which could extract speaker embeddings and further perform speaker verification on paired utterances. It will be supported for speaker diarization in the future version.
- We release a new model, [SV](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary), which could extract speaker embeddings and further perform speaker verification on paired utterances. It will be supported for speaker diarization in the future version.
- We improve the pipeline of modelscope to speedup the inference, by integrating the process of build model into build pipeline.
- Various new types of audio input types are now supported by modelscope inference pipeline, including wav.scp, wav format, audio bytes, wave samples...
@@ -42,7 +60,7 @@
For more details, please ref to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
## Usage
For users who are new to FunASR and ModelScope, please refer to [FunASR Docs](https://alibaba-damo-academy.github.io/FunASR/index.html).
For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
## Contact
@@ -85,4 +103,10 @@
  booktitle={INTERSPEECH},
  year={2022}
}
@inproceedings{Shi2023AchievingTP,
  title={Achieving Timestamp Prediction While Recognizing with Non-Autoregressive End-to-End ASR Model},
  author={Xian Shi and Yanni Chen and Shiliang Zhang and Zhijie Yan},
  booktitle={arXiv preprint arXiv:2301.12343},
  year={2023}
}
```
docs/build_task.md
New file
@@ -0,0 +1,106 @@
# Build custom tasks
Similar to ESPnet, FunASR uses `Task` as the general interface for model training and inference. Each `Task` is a class that inherits from `AbsTask`, which is defined in `funasr/tasks/abs_task.py`. The main methods of `AbsTask` are as follows:
```python
import argparse
from abc import ABC


class AbsTask(ABC):
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        pass
    @classmethod
    def build_preprocess_fn(cls, args, train):
        (...)
    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        (...)
    @classmethod
    def build_model(cls, args):
        (...)
    @classmethod
    def main(cls, args):
        (...)
```
- add_task_arguments: add the parameters required by a specific `Task`
- build_preprocess_fn: define how to preprocess samples
- build_collate_fn: define how to combine multiple samples into a `batch`
- build_model: define the model
- main: the training interface; training is started through `Task.main()`
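To make this division of labour concrete, here is a self-contained sketch of the same template-method pattern. `ToyAbsTask` and `ToyTask` are hypothetical stand-ins, not FunASR classes: the base class's `main` calls the overridable hooks, and a concrete task overrides only the hooks it needs.

```python
import argparse


class ToyAbsTask:
    """Hypothetical stand-in for AbsTask: `main` drives the overridable hooks."""

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser) -> None:
        pass  # subclasses add their own options

    @classmethod
    def build_preprocess_fn(cls, args: argparse.Namespace, train: bool):
        return None  # no preprocessing by default

    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace, train: bool):
        return list  # default: keep the samples as a plain list

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        raise NotImplementedError

    @classmethod
    def main(cls, argv=None):
        parser = argparse.ArgumentParser()
        cls.add_task_arguments(parser)
        args = parser.parse_args(argv)
        # A real task would now build data loaders and run the training loop.
        return (
            cls.build_model(args),
            cls.build_preprocess_fn(args, train=True),
            cls.build_collate_fn(args, train=True),
        )


class ToyTask(ToyAbsTask):
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser) -> None:
        parser.add_argument("--vocab_size", type=int, default=10)

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        return {"vocab_size": args.vocab_size}  # placeholder "model"


model, preprocess, collate = ToyTask.main(["--vocab_size", "32"])
print(model)  # {'vocab_size': 32}
```

The subclass never touches `main`; it only fills in the hooks, which is the property the real `AbsTask` design relies on.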
Next, we take speech recognition as an example to introduce how to define a new `Task`. For the corresponding code, please see `ASRTask` in `funasr/tasks/asr.py`. Defining a new `Task` amounts to redefining the functions above according to the requirements of that `Task`.
- add_task_arguments
```python
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
    group = parser.add_argument_group(description="Task related")
    group.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    (...)
```
For speech recognition tasks, the required parameters include `token_list`, etc. Users can define the parameters required by their own tasks in this function.
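A minimal, self-contained sketch of such an argument group, using only the standard `argparse` module. `str_or_none` here is a hypothetical reimplementation (it maps the literal string `none` to `None`), not an import from FunASR, and `--token_type` with its choices is an illustrative option only.

```python
import argparse
from typing import Optional


def str_or_none(value: str) -> Optional[str]:
    """Hypothetical helper: treat the literal string "none" as None."""
    return None if value.strip().lower() == "none" else value


def add_task_arguments(parser: argparse.ArgumentParser) -> None:
    group = parser.add_argument_group(description="Task related")
    group.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    # Hypothetical second option, for illustration only.
    group.add_argument("--token_type", type=str, default="char", choices=["char", "bpe"])


parser = argparse.ArgumentParser()
add_task_arguments(parser)
args = parser.parse_args(["--token_list", "tokens.txt"])
print(args.token_list, args.token_type)  # tokens.txt char
```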
- build_preprocess_fn
```python
@classmethod
def build_preprocess_fn(cls, args, train):
    if args.use_preprocessor:
        retval = CommonPreprocessor(
            train=train,
            token_type=args.token_type,
            token_list=args.token_list,
            bpemodel=args.bpemodel,
            non_linguistic_symbols=args.non_linguistic_symbols,
            text_cleaner=args.cleaner,
            # ... further preprocessing options are elided here
        )
    else:
        retval = None
    return retval
```
This function defines how to preprocess samples. The input of speech recognition tasks consists of speech and text. For speech, optional operations such as adding noise and reverberation are supported. For text, optional operations such as BPE processing and mapping text to `tokenid` are supported. Users can choose which preprocessing operations to apply to the samples. For the detailed implementation, please refer to `CommonPreprocessor`.
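As an illustration of the optional noise augmentation, the sketch below mixes a noise signal into speech at a target signal-to-noise ratio. This is a generic SNR-mixing technique written for this page under the assumption that signals are plain lists of floats; it is not the code of `CommonPreprocessor`, and `add_noise` is a hypothetical name. The noise list must be non-empty.

```python
import math
from typing import List


def add_noise(speech: List[float], noise: List[float], snr_db: float) -> List[float]:
    """Mix `noise` into `speech` at the requested signal-to-noise ratio (dB)."""
    # Tile the noise so that it covers the whole utterance, then trim it.
    reps = math.ceil(len(speech) / len(noise))
    noise = (noise * reps)[: len(speech)]
    speech_power = sum(x * x for x in speech) / len(speech)
    noise_power = sum(n * n for n in noise) / len(noise)
    if noise_power == 0.0:
        return list(speech)  # silent noise: nothing to add
    # Choose `scale` so that 10 * log10(speech_power / (scale**2 * noise_power)) == snr_db.
    scale = math.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))
    return [x + scale * n for x, n in zip(speech, noise)]
```

Scaling the noise rather than the speech keeps the speech level unchanged, so only the SNR varies between augmented copies.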
- build_collate_fn
```python
@classmethod
def build_collate_fn(cls, args, train):
    return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
```
This function defines how to combine multiple samples into a `batch`. For speech recognition tasks, `padding` is used to bring speech and text of different lengths to the same length within a batch. By default, `0.0` is the padding value for speech and `-1` is the padding value for text. Users can define different `batch` operations here. For the detailed implementation, please refer to `CommonCollateFn`.
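The padding logic can be sketched in pure Python as follows. To keep it short, speech is simplified to a 1-D list of floats instead of a (frames x feature-dim) matrix, and the `(uttids, batch)` return value with `<key>_lengths` entries is an assumption modelled on ESPnet-style collate functions; `collate` and `pad_to_max` are hypothetical names, not `CommonCollateFn` itself.

```python
from typing import Dict, List, Sequence, Tuple

Sample = Tuple[str, Dict[str, Sequence]]


def pad_to_max(seqs: List[Sequence], pad_value) -> List[list]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]


def collate(samples: List[Sample], float_pad_value: float = 0.0, int_pad_value: int = -1):
    uttids = [uttid for uttid, _ in samples]
    batch = {}
    for key in samples[0][1]:
        seqs = [data[key] for _, data in samples]
        # Float sequences (speech) get float_pad_value; integer ones (token ids) get int_pad_value.
        is_float = any(isinstance(x, float) for s in seqs for x in s)
        batch[key] = pad_to_max(seqs, float_pad_value if is_float else int_pad_value)
        batch[key + "_lengths"] = [len(s) for s in seqs]  # true lengths before padding
    return uttids, batch
```

Keeping the true lengths next to the padded tensors lets the model mask out padded frames and the `-1` label padding when computing the loss.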
- build_model
```python
@classmethod
def build_model(cls, args, train):
    # Read the token list; each token's line index is its integer id.
    with open(args.token_list, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    vocab_size = len(token_list)
    # `*_class`, `input_size` and `encoder_output_size` are assumed to be
    # resolved from the configuration in `args` (elided in this excerpt).
    frontend = frontend_class(**args.frontend_conf)
    specaug = specaug_class(**args.specaug_conf)
    normalize = normalize_class(**args.normalize_conf)
    preencoder = preencoder_class(**args.preencoder_conf)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
    postencoder = postencoder_class(input_size=encoder_output_size, **args.postencoder_conf)
    decoder = decoder_class(vocab_size=vocab_size, encoder_output_size=encoder_output_size, **args.decoder_conf)
    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf)
    # Combine the modules into the complete model.
    model = model_class(
        vocab_size=vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        token_list=token_list,
        **args.model_conf,
    )
    return model
```
This function defines the model. Different speech recognition models can usually share the same speech recognition `Task`, so the remaining work is to define the specific model in this function. The example above shows a speech recognition model with a standard encoder-decoder structure: it first builds each module of the model (encoder, decoder, etc.) and then combines these modules into a complete model. In FunASR, the model must inherit from `AbsESPnetModel`, which is defined in `funasr/train/abs_espnet_model.py`, and the main method to implement is `forward`.
docs/get_started.md
@@ -1,21 +1,21 @@
# Get Started
This is an easy example which introduces how to train a paraformer model on AISHELL-1 data from scratch. According to this example, you can train other models (conformer, paraformer, etc.) on other datasets (AISHELL-1, AISHELL-2, etc.) similarly.
Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
## Overall Introduction
We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 data  . This recipe consists of five stages and support training on multiple GPUs and decoding by CPU or GPU. Before introduce each stage in detail, we first explain several variables which should be set by users.
We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
- `CUDA_VISIBLE_DEVICES`: visible gpu list
- `gpu_num`: the number of GPUs used for training
- `gpu_inference`: whether to use GPUs for decoding
- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU.
- `feats_dir`: the path to save processed data
- `exp_dir`: the path to save experimental results
- `data_aishell`: the path of raw AISHELL-1 data
- `tag`: the suffix of experimental result directory
- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
- `data_aishell`: the raw path of AISHELL-1 dataset
- `feats_dir`: the path for saving processed data
- `nj`: the number of jobs for data preparation
- `speed_perturb`: the speed perturbation factors used for data augmentation
- `exp_dir`: the path for saving experimental results
- `tag`: the suffix of experimental result directory
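The meaning of `njob` differs between CPU and GPU decoding. The sketch below spells out the resulting total job count and one possible way to shard utterances across jobs. `num_gpus` (the number of visible GPUs) and the round-robin sharding are assumptions for illustration; the recipe's own data splitting may differ.

```python
from typing import List


def num_decoding_jobs(njob: int, gpu_inference: bool, num_gpus: int) -> int:
    """Total number of decoding jobs implied by `njob`."""
    # CPU decoding: njob is already the total; GPU decoding: njob jobs per GPU.
    return njob * num_gpus if gpu_inference else njob


def split_round_robin(utt_ids: List[str], num_jobs: int) -> List[List[str]]:
    """Hypothetical sharding: job i takes every num_jobs-th utterance starting at i."""
    return [utt_ids[i::num_jobs] for i in range(num_jobs)]
```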
## Stage 0: Data preparation
This stage processes raw AISHELL-1 data `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx` and `xxx` means `train/dev/test`. Here we assume you have already downloaded AISHELL-1 data. If not, you can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. Here we show examples for `wav.scp` and `text`, separately.
This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
* `wav.scp`
```
BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
@@ -30,17 +30,17 @@
BAC009S0002W0124 自 六 月 底 呼 和 浩 特 市 率 先 宣 布 取 消 限 购 后
...
```
We can see that these two files both have two columns while the first column is the wav-id and the second column is the corresponding wav-path/label tokens.
These two files both have two columns, while the first column is wav ids and the second column is the corresponding wav paths/label tokens.
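Both files follow the Kaldi-style `<key> <value>` layout, so a single parser covers them. This is a minimal sketch (`read_two_column` is a hypothetical name) that splits each line on the first run of whitespace, which keeps the space-separated label tokens of `text` intact.

```python
from typing import Dict, Iterable


def read_two_column(lines: Iterable[str]) -> Dict[str, str]:
    """Parse Kaldi-style `<key> <value>` lines into a dict."""
    table = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if not parts:
            continue  # skip blank lines
        table[parts[0]] = parts[1] if len(parts) > 1 else ""
    return table
```

For example, `read_two_column(open("wav.scp", encoding="utf-8"))` would map each wav id to its path; the same call on `text` maps each wav id to its label string.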
## Stage 1: Feature Generation
This stage extracts FBank feature from raw wav `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. You can set `nj` to control the number of jobs for feature generation. The output features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
This stage extracts FBank features from `wav.scp` and applies speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
* `feats.scp`
```
...
BAC009S0002W0122_sp0.9 /nfs/haoneng.lhn/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
...
```
Note that samples in this file have already been shuffled. This file contains two columns. The first column is the wav-id while the second column is the kaldi-ark feature path. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
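The second column of `feats.scp` follows the Kaldi convention `<ark path>:<byte offset>`, where the number after the last colon is the byte position of the utterance's matrix inside the ark file. A minimal sketch (hypothetical `parse_ark_spec`) that splits such an entry; actually reading the matrix would need a Kaldi I/O library such as kaldiio.

```python
from typing import Tuple


def parse_ark_spec(spec: str) -> Tuple[str, int]:
    """Split `path/to/feats.ark:OFFSET` into (ark_path, byte_offset)."""
    path, sep, offset = spec.rpartition(":")
    if not sep or not offset.isdigit():
        raise ValueError(f"expected '<ark path>:<byte offset>', got {spec!r}")
    return path, int(offset)
```

Splitting on the last colon (rather than the first) keeps paths that themselves contain colons intact.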
* `speech_shape`
```
...
@@ -53,10 +53,10 @@
BAC009S0002W0122_sp0.9 15
...
```
These two files have two columns. The first column is the wav-id and the second column is the corresponding speech feature shape and text length.
These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
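Because the shape files record lengths without opening the features, a loader can sort and batch utterances by length cheaply. The sketch below assumes multi-dimensional shapes are comma-separated (for example `frames,feature_dim`), since the body of the `speech_shape` example is not shown here; `parse_shape_file` and the frame-budget batching in `batches_by_length` are hypothetical illustrations, not FunASR's batching code.

```python
from typing import Dict, Iterable, List, Tuple


def parse_shape_file(lines: Iterable[str]) -> Dict[str, Tuple[int, ...]]:
    """Parse `<uttid> <d1>[,<d2>...]` lines into integer shape tuples."""
    shapes = {}
    for line in lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        shapes[parts[0]] = tuple(int(d) for d in parts[1].split(","))
    return shapes


def batches_by_length(shapes: Dict[str, Tuple[int, ...]], max_len_sum: int) -> List[List[str]]:
    """Hypothetical bucketing: sort by length and cap the summed length per batch."""
    batches, current, total = [], [], 0
    for uttid in sorted(shapes, key=lambda u: shapes[u][0]):
        length = shapes[uttid][0]
        if current and total + length > max_len_sum:
            batches.append(current)
            current, total = [], 0
        current.append(uttid)
        total += length
    if current:
        batches.append(current)
    return batches
```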
## Stage 2: Dictionary Preparation
This stage prepares a dictionary, which is used as a mapping between label characters and integer indices during ASR training. The output dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. Here we show an example of `tokens.txt` as follows:
This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/$lang_toekn_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
* `tokens.txt`
```
<blank>
@@ -75,7 +75,7 @@
* `<unk>`: indicates the out-of-vocabulary token
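A minimal sketch of how such a dictionary is used, assuming (as in the `build_model` excerpt, which reads `token_list = [line.rstrip() for line in f]`) that a token's integer id is its line index in `tokens.txt`. The vocabulary in the usage below is a made-up toy list; `load_token_list` and `tokens_to_ids` are hypothetical names.

```python
from typing import Dict, Iterable, List


def load_token_list(lines: Iterable[str]) -> Dict[str, int]:
    """Map each token to its line index in tokens.txt."""
    tokens = [line.rstrip() for line in lines]
    return {token: index for index, token in enumerate(tokens)}


def tokens_to_ids(text: str, token2id: Dict[str, int], unk: str = "<unk>") -> List[int]:
    """Convert space-separated tokens to ids, sending unknown ones to `<unk>`."""
    unk_id = token2id[unk]
    return [token2id.get(token, unk_id) for token in text.split()]


# Toy vocabulary (hypothetical): "月" is not in it, so it maps to <unk>.
toy = load_token_list(["<blank>\n", "<s>\n", "</s>\n", "自\n", "六\n", "<unk>\n"])
print(tokens_to_ids("自 六 月", toy))  # [3, 4, 5]
```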
## Stage 3: Training
This stage trains the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on the validation set are averaged to produce a better model, which is then used for decoding.
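Checkpoint averaging can be pictured with the toy sketch below. It is only an illustration, assuming each checkpoint is a plain dict of flat float lists (real checkpoints hold PyTorch tensors) and that a higher validation score is better; `select_nbest` and `average_checkpoints` are hypothetical helpers, not FunASR functions.

```python
from typing import Dict, List

# Hypothetical toy representation: parameter name -> flat list of floats.
StateDict = Dict[str, List[float]]


def select_nbest(val_scores: Dict[str, float], keep_nbest_models: int,
                 higher_is_better: bool = True) -> List[str]:
    """Return the names of the keep_nbest_models best checkpoints."""
    ranked = sorted(val_scores, key=val_scores.get, reverse=higher_is_better)
    return ranked[:keep_nbest_models]


def average_checkpoints(state_dicts: List[StateDict]) -> StateDict:
    """Element-wise mean of every parameter across checkpoints with identical shapes."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint")
    n = len(state_dicts)
    averaged: StateDict = {}
    for name in state_dicts[0]:
        # zip(*...) walks the same position of this parameter in every checkpoint
        averaged[name] = [sum(values) / n for values in zip(*(sd[name] for sd in state_dicts))]
    return averaged


if __name__ == "__main__":
    ckpts = {"1epoch.pth": {"w": [1.0, 2.0]},
             "2epoch.pth": {"w": [3.0, 4.0]},
             "3epoch.pth": {"w": [5.0, 6.0]}}
    val_acc = {"1epoch.pth": 0.80, "2epoch.pth": 0.90, "3epoch.pth": 0.85}
    best = select_nbest(val_acc, keep_nbest_models=2)
    print(best)                                           # ['2epoch.pth', '3epoch.pth']
    print(average_checkpoints([ckpts[k] for k in best]))  # {'w': [4.0, 5.0]}
```

Averaging only makes sense for checkpoints of the same architecture, which is why the sketch reads every parameter name from the first checkpoint.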
* DDP Training
@@ -83,30 +83,29 @@
* DataLoader
[comment]: <> (We support two types of DataLoaders for small and large datasets, respectively. By default, the small DataLoader is used and you can set `dataset_type=large` to enable large DataLoader. For small DataLoader, )
We support an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large datasets, which users can enable by setting `dataset_type=large`.
* Configuration
The training parameters, including model, optimization, dataset, etc., can be set in a YAML file in the `conf` directory. Users can also set parameters directly in the `run.sh` recipe. Please avoid setting the same parameter in both the YAML file and the recipe.
* Training Steps
We support two parameters to specify the training length, namely `max_epoch` and `max_update`. `max_epoch` indicates the total number of training epochs while `max_update` indicates the total number of training steps. If both parameters are specified, training stops as soon as either limit is reached.
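The "whichever comes first" rule can be sketched as below. This is a toy simulation, not FunASR's trainer; it assumes the update limit is checked after every optimizer step, and both function names are hypothetical.

```python
from typing import Optional, Tuple


def should_stop(epoch: int, update: int,
                max_epoch: Optional[int], max_update: Optional[int]) -> bool:
    """True once either limit is reached; None means that limit is not set."""
    if max_epoch is not None and epoch >= max_epoch:
        return True
    if max_update is not None and update >= max_update:
        return True
    return False


def simulate(steps_per_epoch: int, max_epoch: Optional[int],
             max_update: Optional[int]) -> Tuple[int, int]:
    """Return (completed epochs, total updates) when training stops."""
    if max_epoch is None and max_update is None:
        raise ValueError("set at least one of max_epoch / max_update")
    epoch, update = 0, 0
    while not should_stop(epoch, update, max_epoch, max_update):
        for _ in range(steps_per_epoch):
            update += 1
            # assumption: the update limit is checked after every optimizer step
            if should_stop(epoch, update, max_epoch, max_update):
                return epoch, update
        epoch += 1
    return epoch, update


if __name__ == "__main__":
    print(simulate(100, max_epoch=10, max_update=250))    # (2, 250): max_update hit first
    print(simulate(100, max_epoch=3, max_update=10_000))  # (3, 300): max_epoch hit first
```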
* Tensorboard
Users can use TensorBoard to monitor the loss, learning rate, etc. Run the following command:
```
tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
```
## Stage 4: Decoding
This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
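CER is the character-level edit distance between hypothesis and reference divided by the number of reference characters. The generic sketch below computes it; it is not FunASR's `compute_wer`, and ignoring spaces (so that space-separated Chinese tokens count one character each) is an assumption made here.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance (substitutions + deletions + insertions), one DP row at a time."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution (0 if the characters match)
        prev = cur
    return prev[-1]


def cer(refs: List[str], hyps: List[str]) -> float:
    """Corpus-level CER in percent; spaces are ignored so every character is one unit."""
    errors, total = 0, 0
    for ref, hyp in zip(refs, hyps):
        ref_chars = [c for c in ref if not c.isspace()]
        hyp_chars = [c for c in hyp if not c.isspace()]
        errors += edit_distance(ref_chars, hyp_chars)
        total += len(ref_chars)
    if total == 0:
        raise ValueError("reference is empty")
    return 100.0 * errors / total


if __name__ == "__main__":
    print(cer(["今天天气很好"], ["今天天气不好"]))  # one substitution out of six characters
```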
* Mode Selection
As FunASR supports paraformer, uniasr, conformer and other models, the `mode` parameter should be set to `asr`, `paraformer` or `uniasr` according to the trained model.
* Configuration
docs/images/dingding.jpg

docs/images/wechat.png

docs/index.rst
@@ -16,12 +16,14 @@
   ./installation.md
   ./papers.md
   ./get_started.md
   ./build_task.md
.. toctree::
   :maxdepth: 1
   :caption: ModelScope:
   ./modelscope_models.md
   ./modelscope_usages.md
Indices and tables
==================
docs/installation.md
@@ -1,35 +1,35 @@
# Installation
FunASR is easy to install. The detailed installation steps are as follows:
- Install Conda and create a virtual environment:
```sh
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
sh Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc
conda create -n funasr python=3.7
conda activate funasr
```
- Install PyTorch (version >= 1.7.0):
| CUDA | Command |
|:-----:| --- |
|  9.2  | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=9.2 -c pytorch |
| 10.2  | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch |
| 11.1  | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch |
Alternatively, install it with pip:
```sh
pip install torch torchaudio
```
For more versions, please see [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally)
- Install ModelScope
For users in China, you can configure the following mirror source to speed up downloading:
``` sh
pip config set global.index-url https://mirror.sjtu.edu.cn/pypi/web/simple
```
Install or update ModelScope
```sh
pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
- Clone the repo and install other packages
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install --editable ./
```
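After installation, the sketch below can serve as a quick sanity check that the packages are visible in the active `funasr` environment. It uses only the standard library and locates packages without importing them; the list of import names is an assumption based on the steps above.

```python
import importlib.util
from typing import Dict, Iterable

# Top-level import names of the packages installed above (assumed).
REQUIRED = ("torch", "torchaudio", "modelscope", "funasr")


def check_packages(names: Iterable[str] = REQUIRED) -> Dict[str, bool]:
    """Map each package name to whether it can be found on the current Python path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name:12s} {'OK' if found else 'MISSING'}")
```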
docs/modelscope_usages.md
New file
@@ -0,0 +1,53 @@
# ModelScope Usage
ModelScope is an open-source model-as-a-service platform supported by Alibaba, which provides flexible and convenient model applications for users in academia and industry. For specific usages and open-source models, please refer to [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition). In the speech domain, we provide autoregressive/non-autoregressive speech recognition, speech pre-training, punctuation prediction and other models that are ready to use.
## Overall Introduction
We provide usage examples for different models under `egs_modelscope`, supporting both direct inference with our provided models and finetuning with our provided models as pre-trained initial models. Next, we introduce the model provided in the `egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch` directory, including `infer.py`, `finetune.py` and `infer_after_finetune.py`. Their functions are as follows:
- `infer.py`: perform inference on the specified dataset based on our provided model
- `finetune.py`: employ our provided model as the initial model for finetuning
- `infer_after_finetune.py`: perform inference on the specified dataset based on the finetuned model
## Inference
We provide `infer.py` for inference. With this file, users can perform inference on the specified dataset with our provided model and obtain the corresponding recognition results. If the transcript is given, the `CER` is calculated at the same time. Before performing inference, users can set the following parameters to modify the inference configuration:
* `data_dir`: dataset directory. The directory should contain the wav list file `wav.scp` and the transcript file `text` (optional). For the format of these two files, please refer to the instructions in [Quick Start](./get_started.md). If the `text` file exists, the CER is calculated accordingly; otherwise this step is skipped.
* `output_dir`: the directory for saving the inference results
* `batch_size`: batch size during inference
* `ctc_weight`: some models contain a CTC module; users can set this parameter to specify the weight of the CTC module during inference
In addition to directly setting parameters in `infer.py`, users can also manually set the parameters in the `decoding.yaml` file in the model download directory to modify the inference configuration.
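Before running `infer.py`, it can help to check `data_dir` in advance. The hypothetical helper below is a sketch, assuming the Kaldi-style `wav.scp` layout of one `<wav-id> <path>` pair per line; it is not part of FunASR.

```python
import os
import tempfile


def inspect_data_dir(data_dir: str) -> dict:
    """Check the inference data layout: wav.scp is required, text is optional."""
    wav_scp = os.path.join(data_dir, "wav.scp")
    if not os.path.isfile(wav_scp):
        raise FileNotFoundError(f"{wav_scp} is required")
    with open(wav_scp, encoding="utf-8") as f:
        # each non-empty line is "<wav-id> <wav path>"
        wav_ids = [line.split(maxsplit=1)[0] for line in f if line.strip()]
    return {
        "num_utts": len(wav_ids),
        # CER can only be computed when reference transcripts are present
        "compute_cer": os.path.isfile(os.path.join(data_dir, "text")),
    }


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, "wav.scp"), "w", encoding="utf-8") as f:
            f.write("utt1 /path/to/utt1.wav\nutt2 /path/to/utt2.wav\n")
        print(inspect_data_dir(d))  # {'num_utts': 2, 'compute_cer': False}
```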
## Finetuning
We provide `finetune.py` for finetuning. With this file, users can finetune our provided model, used as the initial model, on the specified dataset to achieve better performance in a specific domain. Before finetuning, users can set the following parameters to modify the finetuning configuration:
* `data_path`: dataset directory. This directory should contain the `train` directory for the training set and the `dev` directory for the validation set. Each directory needs to contain the wav list file `wav.scp` and the transcript file `text`
* `output_dir`: the directory for saving the finetuning results
* `dataset_type`: for small datasets, set it to `small`; for datasets larger than 1000 hours, set it to `large`
* `batch_bins`: batch size. If `dataset_type` is set to `small`, the unit of `batch_bins` is the number of fbank feature frames; if `dataset_type` is set to `large`, the unit of `batch_bins` is milliseconds
* `max_epoch`: the maximum number of training epochs
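Because `batch_bins` changes unit with `dataset_type`, the same number can mean very different batch sizes. The sketch below makes that explicit. The 10 ms frame shift is an assumption (check the frontend settings of your model), and both helpers are hypothetical.

```python
# Assumption: fbank features use a 10 ms frame shift.
FRAME_SHIFT_MS = 10


def choose_dataset_type(total_hours: float) -> str:
    """Rule of thumb from the list above: "large" only for more than 1000 hours of data."""
    return "large" if total_hours > 1000 else "small"


def batch_audio_seconds(dataset_type: str, batch_bins: int,
                        frame_shift_ms: int = FRAME_SHIFT_MS) -> float:
    """Approximate seconds of audio that batch_bins corresponds to."""
    if dataset_type == "small":
        # batch_bins counts fbank feature frames
        return batch_bins * frame_shift_ms / 1000.0
    if dataset_type == "large":
        # batch_bins is already a duration in milliseconds
        return batch_bins / 1000.0
    raise ValueError(f"unknown dataset_type: {dataset_type}")


if __name__ == "__main__":
    dataset_type = choose_dataset_type(total_hours=200)
    print(dataset_type, batch_audio_seconds(dataset_type, batch_bins=1000))  # small 10.0
```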
The following parameters can also be set. However, if there is no special requirement, users can ignore them and directly use the default values we provide:
* `accum_grad`: the number of gradient accumulation steps
* `keep_nbest_models`: select the `keep_nbest_models` models with the best performance and average their parameters to get a better model
* `optim`: the optimizer
* `lr`: the learning rate
* `scheduler`: the learning rate scheduling strategy
* `scheduler_conf`: the parameters of the learning rate scheduling strategy
* `specaug`: the spectral augmentation method
* `specaug_conf`: the parameters of the spectral augmentation
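`accum_grad` trades memory for batch size. Gradients from several mini-batches are summed before a single optimizer step, so the effective batch is roughly `accum_grad` times larger. The scalar toy below is written in plain Python rather than PyTorch. It only illustrates that update schedule and is not FunASR's trainer.

```python
from typing import List, Tuple


def train_with_accumulation(samples: List[Tuple[float, float]], accum_grad: int,
                            lr: float = 0.1, w: float = 0.0) -> Tuple[float, int]:
    """Fit y = w * x with squared loss, stepping once every accum_grad mini-batches.

    Returns the final weight and the number of optimizer updates.
    """
    grad = 0.0
    updates = 0
    for step, (x, y) in enumerate(samples, start=1):
        # gradient of (w * x - y) ** 2, divided by accum_grad so the sum is a mean
        grad += 2.0 * (w * x - y) * x / accum_grad
        if step % accum_grad == 0:
            w -= lr * grad  # one optimizer step for accum_grad mini-batches
            grad = 0.0
            updates += 1
    # any leftover partial accumulation is dropped in this sketch
    return w, updates


if __name__ == "__main__":
    w, updates = train_with_accumulation([(1.0, 2.0)] * 8, accum_grad=4)
    print(updates)  # 2
```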
In addition to directly setting parameters in `finetune.py`, users can also manually set the parameters in the `finetune.yaml` file in the model download directory to modify the finetuning configuration.
## Inference after Finetuning
We provide `infer_after_finetune.py` for inference with a model finetuned by users. With this file, users can perform inference on the specified dataset with the finetuned model and obtain the corresponding recognition results. If the transcript is given, the `CER` is calculated at the same time. Before performing inference, users can set the following parameters to modify the inference configuration:
* `data_dir`: dataset directory. The directory should contain the wav list file `wav.scp` and the transcript file `text` (optional). If the `text` file exists, the CER is calculated accordingly; otherwise this step is skipped.
* `output_dir`: the directory for saving the inference results
* `batch_size`: batch size during inference
* `ctc_weight`: some models contain a CTC module; users can set this parameter to specify the weight of the CTC module during inference
* `decoding_model_name`: the name of the model used for inference
The following parameters can also be set. However, if there is no special requirement, users can ignore them and directly use the default values we provide:
* `modelscope_model_name`: the name of the initial model used for finetuning
* `required_files`: files required for inference when using the ModelScope interface
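Conceptually, `infer_after_finetune.py` reuses the downloaded model directory and only swaps the acoustic model checkpoint. The sketch below restates that step as a standalone function. Like the script, it assumes that `configuration.json` stores the checkpoint name under `model.am_model_name`. The function name is hypothetical.

```python
import json
import os
import shutil
import tempfile
from typing import Iterable


def prepare_decoding_dir(pretrained_model_path: str, output_dir: str,
                         required_files: Iterable[str], decoding_model_name: str) -> None:
    """Copy the files needed for decoding and point configuration.json at the finetuned checkpoint."""
    os.makedirs(output_dir, exist_ok=True)
    for file_name in required_files:
        src = os.path.join(pretrained_model_path, file_name)
        dst = os.path.join(output_dir, file_name)
        if file_name == "configuration.json":
            with open(src, encoding="utf-8") as f:
                config = json.load(f)
            # field name taken from infer_after_finetune.py
            config["model"]["am_model_name"] = decoding_model_name
            with open(dst, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=4)
        else:
            shutil.copy(src, dst)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as pretrained, tempfile.TemporaryDirectory() as out:
        # toy configuration standing in for the downloaded model's file
        with open(os.path.join(pretrained, "configuration.json"), "w", encoding="utf-8") as f:
            json.dump({"model": {"am_model_name": "model.pb"}}, f)
        prepare_decoding_dir(pretrained, out, ["configuration.json"], "valid.acc.ave.pth")
        with open(os.path.join(out, "configuration.json"), encoding="utf-8") as f:
            print(json.load(f)["model"]["am_model_name"])  # valid.acc.ave.pth
```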
## Notes
Some models may have other specific parameters during the finetuning and inference. The usages of these parameters can be found in the `README.md` file in the corresponding directory.
docs_cn/build_task.md
New file
@@ -0,0 +1,105 @@
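# Build a Custom Task
Similar to ESPnet, FunASR uses `Task` as a general interface for model training and inference. Each `Task` is a class that inherits from `AbsTask`, whose code can be found in `funasr/tasks/abs_task.py`. Its main functions and their purposes are as follows: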
```python
import argparse
from abc import ABC

class AbsTask(ABC):
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        pass
    @classmethod
    def build_preprocess_fn(cls, args, train):
        (...)
    @classmethod
    def build_collate_fn(cls, args: argparse.Namespace):
        (...)
    @classmethod
    def build_model(cls, args):
        (...)
    @classmethod
    def main(cls, args):
        (...)
```
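- add_task_arguments: adds the arguments required by a specific `Task`
- build_preprocess_fn: defines how samples are preprocessed
- build_collate_fn: defines how multiple samples are combined into a `batch`
- build_model: defines the model
- main: the training entry point; training is started via `Task.main()`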
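The list suggests a template-method design: a shared entry point calls the classmethod hooks, so a new task only needs to override them. The toy sketch below illustrates that pattern with hypothetical class names. It omits everything FunASR's real `AbsTask` does for data loading and training.

```python
import argparse
from typing import List, Optional


class ToyAbsTask:
    """Hypothetical, stripped-down stand-in for AbsTask: subclasses override the hooks."""

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser) -> None:
        raise NotImplementedError

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        raise NotImplementedError

    @classmethod
    def main(cls, argv: Optional[List[str]] = None):
        # the entry point is shared; only the hooks differ between tasks
        parser = argparse.ArgumentParser()
        cls.add_task_arguments(parser)
        args = parser.parse_args(argv)
        return cls.build_model(args)


class ToyASRTask(ToyAbsTask):
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser) -> None:
        parser.add_argument("--vocab_size", type=int, default=10)

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        # a real task would build a neural network module here
        return {"type": "toy_asr", "vocab_size": args.vocab_size}


if __name__ == "__main__":
    print(ToyASRTask.main(["--vocab_size", "42"]))  # {'type': 'toy_asr', 'vocab_size': 42}
```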
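Below, we take the speech recognition task as an example to show how to define a new `Task`; see `ASRTask` in `funasr/tasks/asr.py` for the code. Defining a new `Task` essentially means redefining the functions above according to the requirements of the task.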
- add_task_arguments
```python
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
    group = parser.add_argument_group(description="Task related")
    group.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    (...)
```
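For the speech recognition task, the task-specific arguments include `token_list`, etc. Users can define the arguments required by their own tasks in this function.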
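The argument group above can be exercised on its own. In this self-contained sketch, `str_or_none` is a local stand-in. Its behavior of mapping the strings `none`, `null` and `nil` to `None` is an assumption modeled on the ESPnet helper of the same name, and it lets a recipe pass `--token_list none` explicitly.

```python
import argparse
from typing import Optional


def str_or_none(value: str) -> Optional[str]:
    """Local stand-in (assumed behavior): none/null/nil in any case become None."""
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return value


parser = argparse.ArgumentParser()
group = parser.add_argument_group(description="Task related")
group.add_argument("--token_list", type=str_or_none, default=None,
                   help="A text mapping int-id to token")

if __name__ == "__main__":
    print(parser.parse_args(["--token_list", "data/tokens.txt"]).token_list)  # data/tokens.txt
    print(parser.parse_args(["--token_list", "none"]).token_list)             # None
```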
- build_preprocess_fn
```python
@classmethod
def build_preprocess_fn(cls, args, train):
    if args.use_preprocessor:
        retval = CommonPreprocessor(
                    train=train,
                    token_type=args.token_type,
                    token_list=args.token_list,
                    bpemodel=args.bpemodel,
                    non_linguistic_symbols=args.non_linguistic_symbols,
                    text_cleaner=args.cleaner,
                    ...
                )
    else:
        retval = None
    return retval
```
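This function defines how samples are preprocessed. Specifically, the inputs of a speech recognition task are audio and transcripts. For audio, it optionally adds noise, reverberation, etc.; for transcripts, it optionally processes them with BPE and maps them to `tokenid`s. Users can choose which preprocessing operations to apply to the samples, and can refer to `CommonPreprocessor` for the implementation.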
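The transcript side of preprocessing (mapping text to `tokenid`s) can be sketched minimally as below. The snippet assumes character-level tokens and a toy vocabulary laid out like `tokens.txt`. It is not the actual `CommonPreprocessor`, which also handles BPE, text cleaning and audio augmentation.

```python
from typing import Dict, List


def build_token2id(token_list: List[str]) -> Dict[str, int]:
    """Line i of the token list becomes id i."""
    return {token: idx for idx, token in enumerate(token_list)}


def text_to_ids(text: str, token2id: Dict[str, int], unk: str = "<unk>") -> List[int]:
    """Character-level tokenization; whitespace is skipped and unknown characters map to <unk>."""
    unk_id = token2id[unk]
    return [token2id.get(ch, unk_id) for ch in text if not ch.isspace()]


if __name__ == "__main__":
    # toy vocabulary in the spirit of tokens.txt: <blank> first, <unk> last
    token2id = build_token2id(["<blank>", "今", "天", "好", "<unk>"])
    print(text_to_ids("今天 好吗", token2id))  # [1, 2, 3, 4]
```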
- build_collate_fn
```python
@classmethod
def build_collate_fn(cls, args, train):
    return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
```
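This function defines how multiple samples are combined into a `batch`. For the speech recognition task, audio and transcripts of different lengths are padded to equal length. Specifically, by default we use `0.0` as the padding value for audio and `-1` as the padding value for transcripts. Users can define different batching operations here and can refer to `CommonCollateFn` for the implementation.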
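The padding logic can be sketched in plain Python as follows. The `*_lengths` keys and list-based batches are assumptions made for readability; `CommonCollateFn` itself works on arrays and tensors. Keeping the true lengths lets the model ignore padded positions, and `-1` as the text padding value cannot collide with a real token id.

```python
from typing import Dict, List, Sequence, Tuple


def pad_batch(seqs: Sequence[Sequence], pad_value) -> Tuple[List[list], List[int]]:
    """Right-pad every sequence to the longest one and keep the original lengths."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


def collate(samples: List[Dict[str, list]]) -> Dict[str, list]:
    # same defaults as described above: 0.0 for audio, -1 for transcripts
    speech, speech_lengths = pad_batch([s["speech"] for s in samples], 0.0)
    text, text_lengths = pad_batch([s["text"] for s in samples], -1)
    return {"speech": speech, "speech_lengths": speech_lengths,
            "text": text, "text_lengths": text_lengths}


if __name__ == "__main__":
    batch = collate([{"speech": [0.1, 0.2, 0.3], "text": [5, 6]},
                     {"speech": [0.4], "text": [7, 8, 9]}])
    print(batch["text"])  # [[5, 6, -1], [7, 8, 9]]
```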
- build_model
```python
@classmethod
def build_model(cls, args, train):
    with open(args.token_list, encoding="utf-8") as f:
        token_list = [line.rstrip() for line in f]
    vocab_size = len(token_list)
    frontend = frontend_class(**args.frontend_conf)
    specaug = specaug_class(**args.specaug_conf)
    normalize = normalize_class(**args.normalize_conf)
    preencoder = preencoder_class(**args.preencoder_conf)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
    postencoder = postencoder_class(input_size=encoder_output_size, **args.postencoder_conf)
    decoder = decoder_class(vocab_size=vocab_size, encoder_output_size=encoder_output_size, **args.decoder_conf)
    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf)
    model = model_class(
        vocab_size=vocab_size,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        token_list=token_list,
        **args.model_conf,
    )
    return model
```
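This function defines the specific model. Different speech recognition models can often share the same speech recognition `Task`; what additionally needs to be done is to define the specific model in this function. For example, the code above shows a standard encoder-decoder speech recognition model: each module of the model, including the encoder and decoder, is defined first, and these modules are then combined into a complete model. In FunASR, a model must inherit from `AbsESPnetModel` (see `funasr/train/abs_espnet_model.py`), and the main function to implement is `forward`.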
docs_cn/get_started.md
@@ -106,7 +106,8 @@
This stage decodes to obtain the recognition results and calculates the CER to verify the performance of the trained model.
* Mode Selection
Since we provide paraformer, uniasr, conformer and other models, the corresponding decoding mode must be specified during decoding. The corresponding parameter is `mode`, with options such as `asr/paraformer/uniasr`.
* Configuration
docs_cn/index.rst
@@ -16,6 +16,7 @@
   ./installation.md
   ./papers.md
   ./get_started.md
   ./build_task.md
.. toctree::
   :maxdepth: 1
docs_cn/modelscope_usages.md
@@ -1,14 +1,14 @@
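# ModelScope Usage
ModelScope is an open-source model-as-a-service sharing platform launched by Alibaba, providing flexible and convenient model application support for users in academia and industry. For specific usage and open-source models, see [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition). In the speech domain, we provide autoregressive/non-autoregressive speech recognition, speech pre-training, punctuation prediction and other models that users can conveniently use.
## Overall Introduction
We provide usage examples for different models under the `egs_modelscope` directory, supporting both direct inference with our provided models and finetuning with our provided models as pre-trained initial models. Below, we use the model provided in the `egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch` directory as an example, including `infer.py`, `finetune.py` and `infer_after_finetune.py`, whose functions are as follows:
- `infer.py`: performs inference on the specified dataset based on our provided model
- `finetune.py`: finetunes our provided model as the initial model
- `infer_after_finetune.py`: performs inference on the specified dataset based on the finetuned model
## Inference
We provide `infer.py` for model inference. With this file, users can perform inference on the specified dataset with our provided model and obtain the corresponding recognition results. If transcripts are given, the `CER` is calculated as well. Before starting inference, users can set the following parameters to modify the inference configuration:
* `data_dir`: dataset directory. The directory should contain the wav list file `wav.scp` and the transcript file `text` (optional); see [Quick Start](./get_started.md) for the specific format. If the `text` file exists, the CER is calculated accordingly; otherwise this step is skipped.
* `output_dir`: the directory for saving the inference results
* `batch_size`: the batch size during inference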
@@ -21,14 +21,14 @@
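* `data_path`: data directory. This directory should contain the `train` directory for the training set and the `dev` directory for the validation set. Each directory needs to contain the wav list file `wav.scp` and the transcript file `text`
* `output_dir`: the directory for saving the finetuning results
* `dataset_type`: for small datasets, set it to `small`; when the data exceeds 1000 hours, set it to `large`
* `batch_bins`: batch size. If `dataset_type` is set to `small`, the unit of `batch_bins` is the number of fbank feature frames; if `dataset_type` is set to `large`, the unit of `batch_bins` is milliseconds
* `max_epoch`: the maximum number of training epochs
The following parameters can also be set. However, if there is no special requirement, they can be ignored and the default values we provide can be used directly:
* `accum_grad`: gradient accumulation
* `keep_nbest_models`: average the parameters of the `keep_nbest_models` best-performing models to obtain a better model
* `optim`: sets the optimizer
* `lr`: sets the learning rate
* `scheduler`: sets the learning rate scheduling strategy
* `scheduler_conf`: parameters of the learning rate scheduling strategy
* `specaug`: sets the spectral augmentation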
@@ -37,7 +37,7 @@
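In addition to directly setting parameters in `finetune.py`, users can also modify the finetuning configuration by manually editing the parameters in the `finetune.yaml` file in the model download directory.
## Inference after Finetuning
We provide `infer_after_finetune.py` for inference with a model finetuned by users themselves. With this file, users can perform inference on the specified dataset with the finetuned model and obtain the corresponding recognition results. If transcripts are given, the CER is calculated as well. Before starting inference, users can set the following parameters to modify the inference configuration:
* `data_dir`: dataset directory. The directory should contain the wav list file `wav.scp` and the transcript file `text` (optional). If the `text` file exists, the CER is calculated accordingly; otherwise this step is skipped.
* `output_dir`: the directory for saving the inference results
* `batch_size`: the batch size during inference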
@@ -45,8 +45,8 @@
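* `decoding_model_name`: specifies the name of the model used for inference
The following parameters can also be set. However, if there is no special requirement, they can be ignored and the default values we provide can be used directly:
* `modelscope_model_name`: the name of the initial model used for finetuning
* `required_files`: files needed for inference when using the ModelScope interface
## Notes
Some models may have model-specific parameters for finetuning and inference; their usage can be found in the `README.md` file in the corresponding directory.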
egs/alimeeting/diarization/sond/README.md
@@ -1,5 +1,24 @@
# Get Started
To use this example, first execute stage 0 of `run.sh` to obtain the prepared data and pre-trained models:
```shell
sh run.sh --stage 0 --stop_stage 0
```
Then, you can execute `unit_test.py` to check the correctness of the code:
```shell
python unit_test.py
# you will get the results:
[{'key': 'R8002_M8002_MS802-S0000_0000000_0001600', 'value': 'spk1 [(0.0, 8.88), (10.72, 11.92), (12.64, 15.2)]\nspk2 [(8.8, 9.76)]\nspk3 [(9.6, 10.96), (15.12, 15.68)]\nspk4 [(11.12, 12.72)]'}]
[{'key': 'R8002_M8002_MS802-S0000_0000000_0001600', 'value': 'spk1 [(0.0, 8.88), (10.72, 11.92), (12.64, 15.2)]\nspk2 [(8.8, 9.76)]\nspk3 [(9.6, 10.96), (15.12, 15.68)]\nspk4 [(11.12, 12.72)]'}]
[{'key': 'R8002_M8002_MS802-S0000_0000000_0001600', 'value': 'spk1 [(0.0, 8.88), (10.72, 11.92), (12.64, 15.2)]\nspk2 [(8.8, 9.76)]\nspk3 [(9.6, 10.88), (15.12, 15.68)]\nspk4 [(11.12, 12.72)]'}]
[{'key': 'test0', 'value': 'spk1 [(0.0, 8.88), (10.64, 15.2)]\nspk2 [(8.88, 9.84)]\nspk3 [(9.6, 11.04), (15.12, 15.68)]\nspk4 [(11.2, 11.76)]'}]
```
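Each `value` in the output above lists, per speaker, the `(start, end)` segments in which that speaker is active, and segments of different speakers may overlap. The hypothetical helper below parses that string. It assumes the format shown and that times are in seconds, which makes per-speaker activity easy to inspect.

```python
import ast
from typing import Dict, List, Tuple

Segments = List[Tuple[float, float]]


def parse_diarization(value: str) -> Dict[str, Segments]:
    """Parse lines such as 'spk1 [(0.0, 8.88), (10.64, 15.2)]' into {speaker: [(start, end), ...]}."""
    result: Dict[str, Segments] = {}
    for line in value.strip().splitlines():
        speaker, segments = line.split(" ", 1)
        # literal_eval safely evaluates the Python list-of-tuples literal
        result[speaker] = [tuple(seg) for seg in ast.literal_eval(segments)]
    return result


def speech_duration(segments: Segments) -> float:
    """Total active time of one speaker."""
    return sum(end - start for start, end in segments)


if __name__ == "__main__":
    value = ("spk1 [(0.0, 8.88), (10.64, 15.2)]\nspk2 [(8.88, 9.84)]\n"
             "spk3 [(9.6, 11.04), (15.12, 15.68)]\nspk4 [(11.2, 11.76)]")
    for spk, segs in parse_diarization(value).items():
        print(spk, round(speech_duration(segs), 2))
```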
You can also execute `run.sh` to reproduce the diarization performance reported in [1]:
```shell
sh run.sh --stage 1 --stop_stage 2
```
# Results
After executing `run.sh`, you will get a DER of about 4.21%, as reported in [1], Table 6, line "SOND Oracle Profile".
# Reference
[1] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, Zhihao Du, Shiliang Zhang, 
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md
New file
@@ -0,0 +1,53 @@
# ModelScope Model
## How to finetune and infer using a pretrained MFCCA model
### Finetune
- Modify finetune training related parameters in `finetune.py`
    - <strong>output_dir:</strong> # result dir
    - <strong>data_path:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
    - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
    - <strong>batch_bins:</strong> # batch size. If `dataset_type` is `small`, `batch_bins` is the number of feature frames; if `dataset_type` is `large`, `batch_bins` is the duration in ms
    - <strong>max_epoch:</strong> # maximum number of training epochs
    - <strong>lr:</strong> # learning rate
- Then you can run the pipeline to finetune with:
```shell
python finetune.py
```
### Inference
Alternatively, you can use the pretrained model for inference directly.
- Setting parameters in `infer.py`
    - <strong>data_dir:</strong> # the dataset dir needs to include `wav.scp`. If `text` also exists, CER will be computed
    - <strong>output_dir:</strong> # result dir
    - <strong>ngpu:</strong> # the number of GPUs for decoding
    - <strong>njob:</strong> # the number of jobs for each GPU
- Then you can run the pipeline to infer with:
```shell
python infer.py
```
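With `ngpu` GPUs and `njob` jobs per GPU, `infer.py` starts `ngpu * njob` worker processes. The sketch below restates its splitting and GPU-assignment arithmetic as two standalone functions, whose names are hypothetical. All remainder lines go to the last job, so if `wav.scp` has fewer lines than jobs, only the last job receives any data.

```python
from typing import List


def split_lines(lines: List[str], ngpu: int, njob: int) -> List[List[str]]:
    """Split wav.scp lines into ngpu * njob chunks; the last chunk takes the remainder."""
    nj = ngpu * njob
    per_job = len(lines) // nj
    chunks, start = [], 0
    for i in range(nj):
        end = len(lines) if i == nj - 1 else start + per_job
        chunks.append(lines[start:end])
        start = end
    return chunks


def gpu_for_job(idx: int, njob: int) -> int:
    """Jobs are numbered from 1; consecutive blocks of njob jobs share one GPU."""
    return (idx - 1) // njob


if __name__ == "__main__":
    chunks = split_lines([f"utt{i} /path/utt{i}.wav" for i in range(10)], ngpu=2, njob=2)
    print([len(c) for c in chunks])                  # [2, 2, 2, 4]
    print([gpu_for_job(i, 2) for i in range(1, 5)])  # [0, 0, 1, 1]
```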
- Results
The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which include the recognition results of each sample with and without the separating character (`src`), as well as the CER of the whole test set.
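The `nosp` variants are produced by removing the separating character `src` from both hypotheses and references before scoring. The scripts do this with plain substring replacement. The sketch below instead drops whole `src` tokens. That choice is an assumption, made so that `src` is not also deleted when it occurs inside a wav ID or another token.

```python
def strip_src(line: str) -> str:
    """Drop standalone 'src' separator tokens and normalize whitespace."""
    return " ".join(token for token in line.split() if token != "src")


if __name__ == "__main__":
    print(strip_src("utt1 今 天 src 好"))  # utt1 今 天 好
```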
### Inference using local finetuned model
- Modify inference related parameters in `infer_after_finetune.py`
    - <strong>output_dir:</strong> # result dir
    - <strong>data_dir:</strong> # the dataset dir needs to include `wav.scp`. If `text` also exists, CER will be computed
    - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
- Then you can run the pipeline to infer with:
```shell
python infer_after_finetune.py
```
- Results
The decoding results can be found in `$output_dir/decode_results/text.sp.cer` and `$output_dir/decode_results/text.nosp.cer`, which include the recognition results of each sample with and without the separating character (`src`), as well as the CER of the whole test set.
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
New file
@@ -0,0 +1,40 @@
# MFCCA
- Model link: <https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
- Model size: 45M
# Environments
- date: `Tue Feb 13 20:13:22 CST 2023`
- python version: `3.7.12`
- FunASR version: `0.1.0`
- pytorch version: `pytorch 1.7.0`
- Git hash: ``
- Commit date: ``
# Benchmark Results
## result (paper)
beam=20, CER tool: https://github.com/yufan-aslp/AliMeeting
|        model        | Para (M) | Data (hrs) | Eval (CER%) | Test (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45   |   917  |   16.1   | 17.5   |
## result (modelscope)
beam=10
with separating character (src)
|        model        | Para (M) | Data (hrs) | Eval_sp (CER%) | Test_sp (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45   |   917  |   17.1   | 18.6   |
without separating character (src)
|        model        | Para (M) | Data (hrs) | Eval_nosp (CER%) | Test_nosp (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45   |   917  |   16.4   | 18.0   |
## Deviation
Because of differences in the CER calculation tool and the decoding beam size, the CER results obtained with ModelScope deviate slightly (<0.5%) from those reported in the paper.
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
New file
@@ -0,0 +1,35 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        model_revision=params.model_revision,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
if __name__ == '__main__':
    params = modelscope_args(model="NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
    params.output_dir = "./checkpoint"              # directory for saving the model
    params.data_path = "./example_data/"            # data path
    params.dataset_type = "small"                   # use "small" for small datasets; use "large" if the data exceeds 1000 hours
    params.batch_bins = 1000                        # batch size: fbank frames if dataset_type="small", milliseconds if dataset_type="large"
    params.max_epoch = 10                           # maximum number of training epochs
    params.lr = 0.0001                              # learning rate
    params.model_revision = 'v1.0.0'
    modelscope_finetune(params)
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
New file
@@ -0,0 +1,103 @@
import os
import shutil
from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
def modelscope_infer_core(output_dir, split_dir, njob, idx):
    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
    gpu_id = (int(idx) - 1) // njob
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    inference_pipline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
        model_revision='v1.0.0',
        output_dir=output_dir_job,
        batch_size=1,
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
def modelscope_infer(params):
    # prepare for multi-GPU decoding
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)
    nj = ngpu * njob
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
        num_lines = len(lines)
        num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end
    p = Pool(nj)
    for i in range(nj):
        p.apply_async(modelscope_infer_core,
                      args=(output_dir, split_dir, njob, str(i + 1)))
    p.close()
    p.join()
    # combine decoding results
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    files = ["text", "token", "score"]
    for file in files:
        with open(os.path.join(best_recog_path, file), "w") as f:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
                with open(job_file) as f_job:
                    lines = f_job.readlines()
                f.writelines(lines)
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file2 = os.path.join(best_recog_path, "token_nosep")
        with open(text_proc_file, 'r') as hyp_reader:
            with open(text_proc_file2, 'w') as hyp_writer:
                for line in hyp_reader:
                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
                    hyp_writer.write(new_context+'\n')
        text_in2 = os.path.join(best_recog_path, "ref_text_nosep")
        with open(text_in, 'r') as ref_reader:
            with open(text_in2, 'w') as ref_writer:
                for line in ref_reader:
                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
                    ref_writer.write(new_context+'\n')
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.sp.cer"))
        compute_wer(text_in2, text_proc_file2, os.path.join(best_recog_path, "text.nosp.cer"))
if __name__ == "__main__":
    params = {}
    params["data_dir"] = "./example_data/validation"
    params["output_dir"] = "./output_dir"
    params["ngpu"] = 1
    params["njob"] = 1
    modelscope_infer(params)
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
New file
@@ -0,0 +1,67 @@
import json
import os
import shutil
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
def modelscope_infer_after_finetune(params):
    # prepare for decoding
    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
    for file_name in params["required_files"]:
        if file_name == "configuration.json":
            with open(os.path.join(pretrained_model_path, file_name)) as f:
                config_dict = json.load(f)
                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(os.path.join(pretrained_model_path, file_name),
                        os.path.join(params["output_dir"], file_name))
    decoding_path = os.path.join(params["output_dir"], "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)
    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=params["output_dir"],
        output_dir=decoding_path,
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)
    # compute CER if the ground-truth text exists
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
        with open(text_proc_file, 'r') as hyp_reader:
            with open(text_proc_file2, 'w') as hyp_writer:
                for line in hyp_reader:
                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
                    hyp_writer.write(new_context+'\n')
        text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
        with open(text_in, 'r') as ref_reader:
            with open(text_in2, 'w') as ref_writer:
                for line in ref_reader:
                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
                    ref_writer.write(new_context+'\n')
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
        compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))
if __name__ == '__main__':
    params = {}
    params["modelscope_model_name"] = "NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
    params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
    params["output_dir"] = "./checkpoint"
    params["data_dir"] = "./example_data/validation"
    params["decoding_model_name"] = "valid.acc.ave.pth"
    modelscope_infer_after_finetune(params)
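The CER step in this script scores the decoded tokens twice: once as-is and once after stripping `src` markers and doubled spaces (the `token_nosep` files). A minimal sketch of that normalization plus a corpus-level character error rate, assuming Kaldi-style `utt_id text` lines. The helper names are hypothetical, and this is not FunASR's `compute_wer`, whose alignment details and report format may differ.

```python
from typing import Dict, Iterable, List


def normalize_text(text: str) -> str:
    """Remove "src" markers and collapse repeated spaces, like the token_nosep step."""
    text = text.strip().replace("src", "")
    # The script applies replace("  ", " ") twice; looping collapses runs of any length.
    while "  " in text:
        text = text.replace("  ", " ")
    return text.strip()


def parse_kaldi_text(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Map utterance ID -> characters of the normalized text, with spaces removed."""
    result: Dict[str, List[str]] = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if not parts:
            continue  # skip blank lines
        # Only the text part is normalized, so IDs containing "src" stay intact.
        body = normalize_text(parts[1]) if len(parts) > 1 else ""
        result[parts[0]] = list(body.replace(" ", ""))
    return result


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance with unit cost for substitutions, insertions and deletions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution or match
        prev = cur
    return prev[-1]


def corpus_cer(ref_lines: Iterable[str], hyp_lines: Iterable[str]) -> float:
    """Total character edits divided by total reference characters."""
    refs = parse_kaldi_text(ref_lines)
    hyps = parse_kaldi_text(hyp_lines)
    # A reference utterance with no hypothesis counts every character as deleted.
    errors = sum(edit_distance(chars, hyps.get(key, [])) for key, chars in refs.items())
    total = sum(len(chars) for chars in refs.values())
    return errors / total if total else 0.0
```

For example, `corpus_cer(open(text_in), open(text_proc_file2))` would give a single CER figure for the normalized hypotheses.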
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
New file
@@ -0,0 +1,19 @@
# ModelScope Model
## How to infer using a pretrained Paraformer-large Model
### Inference
You can use the pretrained model for inference directly.
- Setting parameters in `infer.py`
    - <strong>audio_in:</strong> # Supports wav files, URLs, bytes, and parsed audio formats.
    - <strong>output_dir:</strong> # Required when the input is a wav.scp list; results are written here.
    - <strong>batch_size:</strong> # Batch size used during inference.
    - <strong>param_dict:</strong> # Decoding options; the `hotword` key points to the hotword list (e.g. a URL to a hotword.txt file).
- Then you can run the pipeline to infer with:
```shell
python infer.py
```
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py
New file
@@ -0,0 +1,21 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    param_dict = dict()
    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
    audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
    output_dir = None
    batch_size = 1
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
        output_dir=output_dir,
        batch_size=batch_size,
        param_dict=param_dict)
    rec_result = inference_pipeline(audio_in=audio_in)
    print(rec_result)
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
New file
@@ -0,0 +1,36 @@
import os
import tempfile
import codecs
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.msdatasets import MsDataset
if __name__ == '__main__':
    param_dict = dict()
    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
    output_dir = "./output"
    batch_size = 1
    # dataset split ['test']
    ds_dict = MsDataset.load(dataset_name='speech_asr_aishell1_hotwords_testsets', namespace='speech_asr')
    # mkdtemp() creates the directory and keeps it until it is removed explicitly
    work_dir = tempfile.mkdtemp()
    wav_file_path = os.path.join(work_dir, "wav.scp")
    with codecs.open(wav_file_path, 'w', encoding='utf-8') as fout:
        for line in ds_dict:
            wav = line["Audio:FILE"]
            idx = wav.split("/")[-1].split(".")[0]
            fout.write(idx + " " + wav + "\n")
    audio_in = wav_file_path
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
        output_dir=output_dir,
        batch_size=batch_size,
        param_dict=param_dict)
    rec_result = inference_pipeline(audio_in=audio_in)
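The demo builds a Kaldi-style `wav.scp` (one `utt_id path` pair per line) from the dataset rows. A minimal sketch of that step as reusable helpers, assuming the same line format; the function names are hypothetical. It uses `os.path.splitext` so a name such as `a.b.wav` keeps the ID `a.b` (the demo's `wav.split("/")[-1].split(".")[0]` would yield `a`), and it rejects duplicate IDs, which would otherwise make results ambiguous.

```python
import os
from typing import Iterable, List


def utt_id_from_path(path: str) -> str:
    """Derive an utterance ID from a local path or URL by dropping only the final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def wav_scp_lines(wav_paths: Iterable[str]) -> List[str]:
    """Build Kaldi-style "utt_id path" lines, raising on duplicate IDs."""
    lines: List[str] = []
    seen = set()
    for wav in wav_paths:
        key = utt_id_from_path(wav)
        if key in seen:
            raise ValueError(f"duplicate utterance ID: {key}")
        seen.add(key)
        lines.append(f"{key} {wav}\n")
    return lines
```

In the demo this would become `fout.writelines(wav_scp_lines(line["Audio:FILE"] for line in ds_dict))`.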
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
@@ -23,7 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
def modelscope_infer(params):
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
@@ -23,7 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
def modelscope_infer(params):
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/README.md
@@ -1,6 +1,6 @@
# ModelScope Model
## How to finetune and infer using a pretrained Paraformer-large Model
## How to finetune and infer using a pretrained UniASR Model
### Finetune
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/infer.py
@@ -23,7 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
def modelscope_infer(params):
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/infer_after_finetune.py
@@ -34,7 +34,7 @@
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)
    inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/infer.py
@@ -9,5 +9,5 @@
        model="damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online",
        output_dir=output_dir,
    )
    rec_result = inference_pipline(audio_in=audio_in)
    rec_result = inference_pipline(audio_in=audio_in, param_dict={"decoding_model":"normal"})
    print(rec_result)
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
@@ -23,7 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
def modelscope_infer(params):
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
@@ -34,7 +34,7 @@
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)
    inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer.py
@@ -23,7 +23,7 @@
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)
    inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
def modelscope_infer(params):
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/infer_after_finetune.py
@@ -34,7 +34,7 @@
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)
    inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "normal"})
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
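Across the UniASR recipes in this commit, scripts under `*-offline` directories now pass `param_dict={"decoding_model": "offline"}` and scripts under `*-online` directories pass `"normal"`; the Minnan recipe, whose directory has neither suffix, passes `"normal"`. A minimal sketch that derives the value from the recipe directory name, assuming that convention holds. It keys on the directory rather than the model ID because several `*-offline` scripts in this diff (Cantonese, Indonesian, Japanese, Portuguese) load a model ID ending in `-online`. The helper is hypothetical, not part of FunASR.

```python
from typing import Optional


def uniasr_decoding_model(recipe_dir: str, default: Optional[str] = None) -> str:
    """Return the param_dict["decoding_model"] value the UniASR recipes use.

    "-offline" directories map to "offline", "-online" ones to "normal";
    anything else falls back to `default` or raises.
    """
    name = recipe_dir.rstrip("/").split("/")[-1]
    if name.endswith("-offline"):
        return "offline"
    if name.endswith("-online"):
        return "normal"
    if default is not None:
        return default
    raise ValueError(f"cannot infer decoding_model from {recipe_dir!r}")
```

Usage would look like `inference_pipline(audio_in=audio_in, param_dict={"decoding_model": uniasr_decoding_model(recipe_dir)})`.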
funasr/bin/asr_inference_launch.py
@@ -228,6 +228,9 @@
    elif mode == "vad":
        from funasr.bin.vad_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    elif mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference_modelscope
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
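The launcher maps each `mode` string to a module imported lazily inside an `elif` chain, once for the ModelScope entry point and once for the CLI entry point. A minimal sketch of the same dispatch as a lookup table, assuming only the module and function names visible in these hunks; `FUNASR_MODES` and `resolve_entry` are hypothetical. Keeping both entry-point names in one row makes it harder to return the wrong one of the two functions for a mode.

```python
import importlib
import logging
from typing import Callable, Dict, Optional, Tuple

# Hypothetical registry: mode -> (module, CLI entry point, ModelScope entry point).
FUNASR_MODES: Dict[str, Tuple[str, str, str]] = {
    "vad": ("funasr.bin.vad_inference", "inference", "inference_modelscope"),
    "mfcca": ("funasr.bin.asr_inference_mfcca", "inference", "inference_modelscope"),
}


def resolve_entry(registry: Dict[str, Tuple[str, str, str]], mode: str,
                  modelscope: bool) -> Optional[Callable]:
    """Import the module for `mode` on demand and return the requested entry point."""
    if mode not in registry:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
    module_name, cli_fn, ms_fn = registry[mode]
    module = importlib.import_module(module_name)  # deferred, like the elif imports
    return getattr(module, ms_fn if modelscope else cli_fn)
```

A launcher branch would then reduce to `entry = resolve_entry(FUNASR_MODES, mode, modelscope=False)` followed by `return entry(**kwargs) if entry else None`.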
@@ -253,6 +256,9 @@
    elif mode == "vad":
        from funasr.bin.vad_inference import inference
        return inference(**kwargs)
    elif mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference
        return inference(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
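In `funasr/bin/asr_inference_mfcca.py`, `Speech2Text` configures its beam search with per-scorer weights (`decoder = 1.0 - ctc_weight`, `ctc = ctc_weight`, `lm`, `ngram`, `length_bonus = penalty`) and afterwards trims `<sos>`/`<eos>` and the blank ID 0 from each hypothesis. A minimal sketch of both pieces, assuming the ESPnet-style `BeamSearch` it imports ranks candidates by the weighted sum of scorer outputs; the function names are hypothetical, and the real search also handles pre-beam pruning and per-scorer state.

```python
from typing import Dict, List, Sequence


def fusion_weights(ctc_weight: float, lm_weight: float,
                   ngram_weight: float, penalty: float) -> Dict[str, float]:
    """Same weight layout that Speech2Text passes to BeamSearch."""
    return dict(decoder=1.0 - ctc_weight, ctc=ctc_weight, lm=lm_weight,
                ngram=ngram_weight, length_bonus=penalty)


def fused_score(partial_scores: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of per-scorer log-probabilities for one candidate token.

    Scorers absent from `partial_scores` (e.g. the disabled ngram, or no LM) add nothing.
    """
    return sum(weights.get(name, 0.0) * score for name, score in partial_scores.items())


def strip_hypothesis(yseq: Sequence[int], blank_id: int = 0) -> List[int]:
    """Drop the leading <sos> and trailing <eos>, then any blank IDs."""
    return [tok for tok in list(yseq)[1:-1] if tok != blank_id]
```

For a tensor `hyp.yseq`, call `.tolist()` first, as `Speech2Text.__call__` does, before passing it to `strip_hypothesis`.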
funasr/bin/asr_inference_mfcca.py
New file
@@ -0,0 +1,764 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr.modules.beam_search.beam_search import BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}
class Speech2Text:
    """Speech2Text class
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """
    def __init__(
            self,
            asr_train_config: Union[Path, str] = None,
            asr_model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            lm_train_config: Union[Path, str] = None,
            lm_file: Union[Path, str] = None,
            token_type: str = None,
            bpemodel: str = None,
            device: str = "cpu",
            maxlenratio: float = 0.0,
            minlenratio: float = 0.0,
            batch_size: int = 1,
            dtype: str = "float32",
            beam_size: int = 20,
            ctc_weight: float = 0.5,
            lm_weight: float = 1.0,
            ngram_weight: float = 0.9,
            penalty: float = 0.0,
            nbest: int = 1,
            streaming: bool = False,
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            lm.to(device)
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        #beam_search.__class__ = BatchBeamSearch
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
    @torch.no_grad()
    def __call__(
            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference
        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        #speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        speech = speech.to(getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        assert len(enc) == 1, len(enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
# def inference(
#         maxlenratio: float,
#         minlenratio: float,
#         batch_size: int,
#         beam_size: int,
#         ngpu: int,
#         ctc_weight: float,
#         lm_weight: float,
#         penalty: float,
#         log_level: Union[int, str],
#         data_path_and_name_and_type,
#         asr_train_config: Optional[str],
#         asr_model_file: Optional[str],
#         cmvn_file: Optional[str] = None,
#         lm_train_config: Optional[str] = None,
#         lm_file: Optional[str] = None,
#         token_type: Optional[str] = None,
#         key_file: Optional[str] = None,
#         word_lm_train_config: Optional[str] = None,
#         bpemodel: Optional[str] = None,
#         allow_variable_data_keys: bool = False,
#         streaming: bool = False,
#         output_dir: Optional[str] = None,
#         dtype: str = "float32",
#         seed: int = 0,
#         ngram_weight: float = 0.9,
#         nbest: int = 1,
#         num_workers: int = 1,
#         **kwargs,
# ):
#     assert check_argument_types()
#     if batch_size > 1:
#         raise NotImplementedError("batch decoding is not implemented")
#     if word_lm_train_config is not None:
#         raise NotImplementedError("Word LM is not implemented")
#     if ngpu > 1:
#         raise NotImplementedError("only single GPU decoding is supported")
#
#     logging.basicConfig(
#         level=log_level,
#         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
#     )
#
#     if ngpu >= 1 and torch.cuda.is_available():
#         device = "cuda"
#     else:
#         device = "cpu"
#
#     # 1. Set random-seed
#     set_all_random_seed(seed)
#
#     # 2. Build speech2text
#     speech2text_kwargs = dict(
#         asr_train_config=asr_train_config,
#         asr_model_file=asr_model_file,
#         cmvn_file=cmvn_file,
#         lm_train_config=lm_train_config,
#         lm_file=lm_file,
#         token_type=token_type,
#         bpemodel=bpemodel,
#         device=device,
#         maxlenratio=maxlenratio,
#         minlenratio=minlenratio,
#         dtype=dtype,
#         beam_size=beam_size,
#         ctc_weight=ctc_weight,
#         lm_weight=lm_weight,
#         ngram_weight=ngram_weight,
#         penalty=penalty,
#         nbest=nbest,
#         streaming=streaming,
#     )
#     logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
#     speech2text = Speech2Text(**speech2text_kwargs)
#
#     # 3. Build data-iterator
#     loader = ASRTask.build_streaming_iterator(
#         data_path_and_name_and_type,
#         dtype=dtype,
#         batch_size=batch_size,
#         key_file=key_file,
#         num_workers=num_workers,
#         preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
#         collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
#         allow_variable_data_keys=allow_variable_data_keys,
#         inference=True,
#     )
#
#     finish_count = 0
#     file_count = 1
#     # 7 .Start for-loop
#     # FIXME(kamo): The output format should be discussed about
#     asr_result_list = []
#     if output_dir is not None:
#         writer = DatadirWriter(output_dir)
#     else:
#         writer = None
#
#     for keys, batch in loader:
#         assert isinstance(batch, dict), type(batch)
#         assert all(isinstance(s, str) for s in keys), keys
#         _bs = len(next(iter(batch.values())))
#         assert len(keys) == _bs, f"{len(keys)} != {_bs}"
#         #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
#
#         # N-best list of (text, token, token_int, hyp_object)
#         try:
#             results = speech2text(**batch)
#         except TooShortUttError as e:
#             logging.warning(f"Utterance {keys} {e}")
#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
#             results = [[" ", ["<space>"], [2], hyp]] * nbest
#
#         # Only supporting batch_size==1
#         key = keys[0]
#         for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
#             # Create a directory: outdir/{n}best_recog
#             if writer is not None:
#                 ibest_writer = writer[f"{n}best_recog"]
#
#                 # Write the result to each file
#                 ibest_writer["token"][key] = " ".join(token)
#                 ibest_writer["token_int"][key] = " ".join(map(str, token_int))
#                 ibest_writer["score"][key] = str(hyp.score)
#
#             if text is not None:
#                 text_postprocessed = postprocess_utils.sentence_postprocess(token)
#                 item = {'key': key, 'value': text_postprocessed}
#                 asr_result_list.append(item)
#                 finish_count += 1
#                 asr_utils.print_progress(finish_count / file_count)
#                 if writer is not None:
#                     ibest_writer["text"][key] = text
#     return asr_result_list
def inference(
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        streaming: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        batch_size=batch_size,
        beam_size=beam_size,
        ngpu=ngpu,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        penalty=penalty,
        log_level=log_level,
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        raw_inputs=raw_inputs,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        key_file=key_file,
        word_lm_train_config=word_lm_train_config,
        bpemodel=bpemodel,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        output_dir=output_dir,
        dtype=dtype,
        seed=seed,
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7. Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list
    return _forward
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--raw_inputs", type=list, default=None)
    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global cmvn file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If this option is specified, *_train_config and "
             "*_file will be overwritten",
    )
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
             "If maxlenratio=0.0 (default), it uses an end-detect "
             "function "
             "to automatically find maximum hypothesis lengths. "
             "If maxlenratio<0.0, its absolute value is interpreted "
             "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
             "If not given, it is inferred from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
             "If not given, it is inferred from the training args",
    )
    return parser
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)
if __name__ == "__main__":
    main()
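In the file above, `inference_modelscope` performs the expensive setup once: it seeds the RNG and builds `Speech2Text`. It then returns the `_forward` closure, which `inference` calls with the actual inputs. The sketch below shows that build-once/call-many pattern without any framework, as a minimal sketch under stated assumptions. `build_pipeline`, `load_model` and the upper-casing "model" are hypothetical stand-ins, not FunASR APIs.

```python
from typing import Callable, Dict, List, Optional


def build_pipeline(batch_size: int = 1) -> Callable[..., List[Dict[str, str]]]:
    """Hypothetical stand-in for inference_modelscope: set up once, return _forward."""
    if batch_size > 1:
        # Same guard as the original: only batch_size == 1 is supported.
        raise NotImplementedError("batch decoding is not implemented")

    load_count = {"n": 0}

    def load_model() -> Callable[[str], str]:
        # Hypothetical stand-in for Speech2Text(**speech2text_kwargs); runs once.
        load_count["n"] += 1
        return lambda utterance: utterance.upper()

    model = load_model()

    def _forward(inputs: Dict[str, str],
                 output_dir_v2: Optional[str] = None) -> List[Dict[str, str]]:
        # Every call reuses the already-built model. output_dir_v2 only mirrors
        # the original signature and is ignored in this sketch.
        return [{"key": key, "value": model(utt)} for key, utt in inputs.items()]

    _forward.load_count = load_count  # exposed only so the demo can inspect it
    return _forward
```

Calling the returned function several times reuses the same model, so `load_count["n"]` stays at 1. This mirrors why the ModelScope pipeline wraps `_forward` instead of rebuilding `Speech2Text` per request.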
funasr/bin/asr_inference_paraformer.py
@@ -6,6 +6,8 @@
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -175,10 +177,24 @@
        self.converter = converter
        self.tokenizer = tokenizer
        # 6. [Optional] Build hotword list from file or str
        # 6. [Optional] Build hotword list from str, local file or url
        # for None
        if hotword_list_or_file is None:
            self.hotword_list = None
        # for text str input
        elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
            logging.info("Attempting to parse hotwords as str...")
            self.hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            self.hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
        # for local txt inputs
        elif os.path.exists(hotword_list_or_file):
            logging.info("Attempting to parse hotwords from local txt...")
            self.hotword_list = []
            hotword_str_list = []
            with codecs.open(hotword_list_or_file, 'r') as fin:
@@ -186,20 +202,31 @@
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                self.hotword_list.append([1])
                self.hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                .format(hotword_list_or_file, hotword_str_list))
        # for url, download and generate txt
        else:
            logging.info("Attempting to parse hotwords as str...")
            logging.info("Attempting to parse hotwords from url...")
            work_dir = tempfile.TemporaryDirectory().name
            if not os.path.exists(work_dir):
                os.makedirs(work_dir)
            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
            local_file = requests.get(hotword_list_or_file)
            open(text_file_path, "wb").write(local_file.content)
            hotword_list_or_file = text_file_path
            self.hotword_list = []
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
            self.hotword_list.append([1])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hotword_str_list.append(hw)
                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                self.hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
                .format(hotword_list_or_file, hotword_str_list))
        is_use_lm = lm_weight != 0.0 and lm_file is not None
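As we read the new side of this hunk, the hotword argument is dispatched in a fixed order:

- `None` means no hotwords.
- A string that is neither an existing path nor starts with `http` is treated as inline, whitespace-separated hotwords.
- An existing path is read as a local file with one hotword per line.
- Anything else is downloaded as a URL and then read like a local file.

In every case each hotword is split into single characters, and a final sentinel entry (`[self.asr_model.sos]`) is appended. The helpers below are hypothetical sketches of that logic. They use the string `"<s>"` in place of the sos id and skip `converter.tokens2ids`.

```python
import os
from typing import List, Optional


def classify_hotword_source(arg: Optional[str]) -> str:
    """Return which branch the hotword argument would take (hypothetical helper)."""
    if arg is None:
        return "none"
    if not os.path.exists(arg) and not arg.startswith("http"):
        return "str"   # inline, whitespace-separated hotwords
    if os.path.exists(arg):
        return "file"  # local text file, one hotword per line
    return "url"       # downloaded first, then parsed like a local file


def split_hotwords(text: str, one_per_line: bool, sentinel: str = "<s>") -> List[List[str]]:
    """Split hotwords into per-character token lists and append a sentinel entry.

    The sentinel stands in for [self.asr_model.sos]; the real code maps each
    character to an id with converter.tokens2ids, which is omitted here.
    """
    if one_per_line:
        words = [line.strip() for line in text.splitlines()]
    else:
        words = text.strip().split()
    tokens = [list(word) for word in words]
    tokens.append([sentinel])
    return tokens
```

For example, `split_hotwords("阿里 达摩院", one_per_line=False)` yields one list of characters per hotword, followed by `["<s>"]`.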
funasr/bin/asr_inference_paraformer_timestamp.py
@@ -455,16 +455,6 @@
    return asr_result_list
def set_parameters(language: str = None,
                   sample_rate: Union[int, Dict[Any, int]] = None):
    if language is not None:
        global global_asr_language
        global_asr_language = language
    if sample_rate is not None:
        global global_sample_rate
        global_sample_rate = sample_rate
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
funasr/bin/asr_inference_paraformer_vad.py
@@ -38,7 +38,6 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.bin.punctuation_infer import Text2Punc
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -39,7 +39,7 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer
@@ -282,12 +282,8 @@
                else:
                    text = None
                if isinstance(self.asr_model, BiCifParaformer):
                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
                else:
                    time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
                    results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
                timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
                results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
        # assert check_return_type(results)
        return results
@@ -617,7 +613,7 @@
                result = result_segments[0]
                text, token, token_int = result[0], result[1], result[2]
                time_stamp = None if len(result) < 4 else result[3]
                if use_timestamp and time_stamp is not None: 
                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                else:
funasr/bin/asr_inference_uniasr.py
@@ -397,7 +397,7 @@
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
@@ -439,6 +439,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
funasr/bin/asr_inference_uniasr_vad.py
@@ -439,6 +439,18 @@
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        if param_dict is not None and "decoding_model" in param_dict:
            if param_dict["decoding_model"] == "fast":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model1"
            elif param_dict["decoding_model"] == "normal":
                speech2text.decoding_ind = 0
                speech2text.decoding_mode = "model2"
            elif param_dict["decoding_model"] == "offline":
                speech2text.decoding_ind = 1
                speech2text.decoding_mode = "model2"
            else:
                raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
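`asr_inference_uniasr.py` and `asr_inference_uniasr_vad.py` add the same `if`/`elif` chain. It maps `param_dict["decoding_model"]` to a `(decoding_ind, decoding_mode)` pair on `speech2text`. One possible table-driven equivalent is sketched below. `DECODING_MODELS` and `resolve_decoding_model` are hypothetical names, and the caller would still assign the returned pair to `speech2text`.

```python
from typing import Dict, Optional, Tuple

# (decoding_ind, decoding_mode) for each decoding_model value, as set in _forward.
DECODING_MODELS: Dict[str, Tuple[int, str]] = {
    "fast": (0, "model1"),
    "normal": (0, "model2"),
    "offline": (1, "model2"),
}


def resolve_decoding_model(param_dict: Optional[dict]) -> Optional[Tuple[int, str]]:
    """Return the pair to assign, or None when no decoding_model is requested."""
    if param_dict is None or "decoding_model" not in param_dict:
        return None
    name = param_dict["decoding_model"]
    if name not in DECODING_MODELS:
        raise NotImplementedError("unsupported decoding model {}".format(name))
    return DECODING_MODELS[name]
```

With a single table, both files would stay in sync when a new mode is added.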
funasr/bin/build_trainer.py
@@ -27,6 +27,8 @@
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
funasr/bin/vad_inference.py
@@ -1,6 +1,7 @@
import argparse
import logging
import sys
import json
from pathlib import Path
from typing import Any
from typing import List
@@ -105,17 +106,32 @@
            feats_len = feats_len.int()
        else:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech}
        # batch = {"feats": feats, "waveform": speech, "is_final_send": True}
        # segments = self.vad_model(**batch)
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        segments = self.vad_model(**batch)
        # b. Forward Encoder streaming
        segments = []
        step = 6000
        t_offset = 0
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
                is_final_send = True
            else:
                is_final_send = False
            batch = {
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final_send": is_final_send
            }
            # a. To device
            batch = to_device(batch, device=self.device)
            segments_part = self.vad_model(**batch)
            if segments_part:
                segments += segments_part
        #print(segments)
        return segments
def inference(
@@ -152,11 +168,12 @@
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        #data_path_and_name_and_type,
        # data_path_and_name_and_type,
        vad_infer_config: Optional[str],
        vad_model_file: Optional[str],
        vad_cmvn_file: Optional[str] = None,
@@ -167,7 +184,6 @@
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
@@ -201,11 +217,11 @@
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
    ):
        # 3. Build data-iterator
        loader = VADTask.build_streaming_iterator(
@@ -243,9 +259,11 @@
            # do vad segment
            results = speech2vadsegment(**batch)
            for i, _ in enumerate(keys):
                results[i] = json.dumps(results[i])
                item = {'key': keys[i], 'value': results[i]}
                vad_results.append(item)
                if writer is not None:
                    results[i] = json.loads(results[i])
                    ibest_writer["text"][keys[i]] = "{}".format(results[i])
        return vad_results
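The new streaming loop in `Speech2VadSegment` feeds the VAD model in chunks of `step = 6000` feature frames. For a chunk starting at frame `t`, the waveform slice is `[t * 160, min(N, (t + step - 1) * 160 + 400))`. The constants 160 and 400 are presumably a 10 ms frame shift and a 25 ms window at 16 kHz, which is our assumption.

The hypothetical helper below computes the same chunk boundaries so they can be inspected. If we read the original loop correctly, there is an edge case when the frame count is one more than a multiple of `step` (for example 12001). In that case `range()` still yields one more offset after `is_final_send` has been set, so the last frame is sent a second time. This sketch stops at the first final chunk instead.

```python
from typing import List, Tuple


def plan_vad_chunks(num_frames: int, num_samples: int, step: int = 6000,
                    frame_shift: int = 160, frame_length: int = 400
                    ) -> List[Tuple[int, int, int, int, bool]]:
    """Return (frame_start, frame_end, sample_start, sample_end, is_final) per chunk."""
    chunks = []
    t = 0
    while t < num_frames:
        # Same finality test as the original loop.
        is_final = t + step >= num_frames - 1
        cur = num_frames - t if is_final else step
        sample_start = t * frame_shift
        # Last frame of the chunk starts at (t + cur - 1) * shift and spans frame_length samples.
        sample_end = min(num_samples, (t + cur - 1) * frame_shift + frame_length)
        chunks.append((t, t + cur, sample_start, sample_end, is_final))
        if is_final:
            break
        t += step
    return chunks
```

For 15000 frames this gives three chunks. The last one is final and covers frames 12000 to 15000.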
funasr/bin/vad_inference_launch.py
@@ -107,13 +107,15 @@
def inference_launch(mode, **kwargs):
    if mode == "vad":
    if mode == "offline":
        from funasr.bin.vad_inference import inference_modelscope
        return inference_modelscope(**kwargs)
    elif mode == "online":
        from funasr.bin.vad_inference_online import inference_modelscope
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
funasr/datasets/iterable_dataset.py
@@ -174,90 +174,94 @@
    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        count = 0
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            linenum = len(self.path_name_type_list)
            data = {}
            value = self.path_name_type_list[0][0]
            uid = 'utt_id'
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            if self.fs is not None and name == "speech":
                audio_fs = self.fs["audio_fs"]
                model_fs = self.fs["model_fs"]
                if audio_fs is not None and model_fs is not None:
                    array = torch.from_numpy(array)
                    array = array.unsqueeze(0)
                    array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                   new_freq=model_fs)(array)
                    array = array.squeeze(0).numpy()
            data[name] = array
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = 'utt_id'
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                       new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            linenum = len(self.path_name_type_list)
            data = {}
            value = self.path_name_type_list[0][0]
            uid = os.path.basename(self.path_name_type_list[0][0]).split(".")[0]
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            if _type == "sound":
                audio_type = os.path.basename(value).split(".")[1].lower()
                if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                    raise NotImplementedError(
                        f'Not supported audio type: {audio_type}')
                if audio_type == "pcm":
                    _type = "pcm"
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                    audio_type = os.path.basename(value).split(".")[-1].lower()
                    if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                        raise NotImplementedError(
                            f'Not supported audio type: {audio_type}')
                    if audio_type == "pcm":
                        _type = "pcm"
            func = DATA_TYPES[_type]
            array = func(value)
            if self.fs is not None and name == "speech":
                audio_fs = self.fs["audio_fs"]
                model_fs = self.fs["model_fs"]
                if audio_fs is not None and model_fs is not None:
                    array = torch.from_numpy(array)
                    array = array.unsqueeze(0)
                    array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                           new_freq=model_fs)(array)
                    array = array.squeeze(0).numpy()
            data[name] = array
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
@@ -322,7 +326,7 @@
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).split(".")[1].lower()
                        audio_type = os.path.basename(value).split(".")[-1].lower()
                        if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                            raise NotImplementedError(
                                f'Not supported audio type: {audio_type}')
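This hunk derives the audio type with `split(".")[-1]` instead of `split(".")[1]`. The sketch below shows why the last component is the safer choice when a file name contains more than one dot. The path is a hypothetical example, and `audio_extension` is an illustrative helper, not a FunASR function.

```python
import os


def audio_extension(path: str) -> str:
    """Return the lower-cased text after the *last* dot of the file name."""
    return os.path.basename(path).split(".")[-1].lower()


# Hypothetical path with several dots in the file name.
path = "/data/wav/spk1.utt.001.WAV"
print(os.path.basename(path).split(".")[1].lower())  # old indexing: "utt"
print(audio_extension(path))  # new indexing: "wav"
```

`os.path.splitext(path)[1]` also isolates the last suffix, but it keeps the leading dot and returns an empty string for names that have no dot at all.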
funasr/datasets/large_datasets/dataset.py
@@ -1,9 +1,10 @@
import os
import random
import soundfile
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -117,7 +118,9 @@
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        mat, sampling_rate = soundfile.read(path)
                        waveform, sampling_rate = torchaudio.load(path)
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
funasr/datasets/preprocessor.py
@@ -363,7 +363,7 @@
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = forward_segment("".join(tokens).lower(), self.seg_dict)
                    tokens = forward_segment("".join(tokens), self.seg_dict)
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
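This hunk stops lower-casing the joined text before `forward_segment`, so dictionary lookups now see the original casing. The sketch below assumes that `forward_segment` performs greedy forward longest-match against `seg_dict`. `forward_longest_match` is a hypothetical stand-in written for illustration, not the FunASR helper itself.

```python
from typing import Container, List


def forward_longest_match(text: str, vocab: Container[str]) -> List[str]:
    """Greedy forward maximum matching: at each position take the longest
    substring found in `vocab`, falling back to a single character."""
    tokens = []
    i = 0
    while i < len(text):
        best = text[i]  # single-character fallback
        for j in range(i + 2, len(text) + 1):
            if text[i:j] in vocab:
                best = text[i:j]  # later j means a longer match
        tokens.append(best)
        i += len(best)
    return tokens


vocab = {"FunASR", "model"}  # hypothetical segmentation dictionary
print(forward_longest_match("FunASRmodel", vocab))  # ['FunASR', 'model']
print(forward_longest_match("FunASRmodel".lower(), vocab))  # ['f', 'u', 'n', 'a', 's', 'r', 'model']
```

With lower-casing, mixed-case dictionary entries such as `FunASR` can no longer match, and the text falls back to single characters.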
funasr/export/README.md
@@ -9,35 +9,35 @@
Installation is the same as for [funasr](../../README.md).
## Export onnx format model
## Export model
   `Tips`: torch 1.11.0 is required.
   ```shell
   python -m funasr.export.export_model [model_name] [export_dir] [onnx]
   ```
   `model_name`: the model to export. It can be a model from ModelScope or a locally fine-tuned model (named `model.pb`).
   `export_dir`: the directory where the exported model is saved.
   `onnx`: `true` to export an ONNX model; `false` to export a TorchScript model.
## Examples
### Export onnx format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```shell
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
Export a model from a local path; the model file must be named `model.pb`.
```shell
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
Export model from local path
```python
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
## Export torchscripts format model
### Export torchscripts format model
Export model from modelscope
```python
from funasr.export.export_model import ASRModelExportParaformer
output_dir = "../export"  # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```shell
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
```
Export model from local path
```python
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
Export a model from a local path; the model file must be named `model.pb`.
```shell
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
```
funasr/export/export_model.py
@@ -24,7 +24,7 @@
            feats_dim=560,
            onnx=False,
        )
        logging.info("output dir: {}".format(self.cache_dir))
        print("output dir: {}".format(self.cache_dir))
        self.onnx = onnx
        
@@ -44,13 +44,13 @@
            model,
            self.export_config,
        )
        self._export_onnx(model, verbose, export_dir)
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)
        else:
            self._export_torchscripts(model, verbose, export_dir)
        logging.info("output dir: {}".format(export_dir))
        print("output dir: {}".format(export_dir))
    def _export_torchscripts(self, model, verbose, path, enc_size=None):
@@ -117,7 +117,15 @@
        )
if __name__ == '__main__':
    output_dir = "../export"
    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
    export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
    import sys
    model_path = sys.argv[1]
    output_dir = sys.argv[2]
    onnx = sys.argv[3]
    onnx = onnx.lower()
    onnx = onnx == 'true'
    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
    # output_dir = "../export"
    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx)
    export_model.export(model_path)
    # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
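The `__main__` block now reads `model_path`, `output_dir` and the `onnx` flag from `sys.argv` by position. The sketch below implements the same three-argument interface with `argparse`. `str2bool` and `build_parser` are hypothetical helpers; the parser only collects arguments and does not call `ASRModelExportParaformer`.

```python
import argparse


def str2bool(value: str) -> bool:
    """Mirror the original parsing: only 'true' (any case) means True."""
    return value.lower() == "true"


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Export a Paraformer model.")
    parser.add_argument("model_name", help="ModelScope model id or local model directory")
    parser.add_argument("export_dir", help="directory where the exported model is written")
    parser.add_argument("onnx", type=str2bool, help="'true' for ONNX, 'false' for TorchScript")
    return parser


args = build_parser().parse_args(
    ["damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", "./export", "True"]
)
print(args.onnx)  # True
```

Unlike indexing into `sys.argv`, `argparse` prints a usage message when an argument is missing instead of raising `IndexError`.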
funasr/export/models/e2e_asr_paraformer.py
@@ -59,7 +59,7 @@
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().long()
        pre_token_length = pre_token_length.round().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
funasr/export/models/predictor/cif.py
@@ -116,53 +116,3 @@
        pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
def CifPredictorV2_test():
    x = torch.rand([2, 21, 2])
    x_len = torch.IntTensor([6, 21])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]
    predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
    # cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
    predictor_scripts.save('test.pt')
    loaded = torch.jit.load('test.pt')
    cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
    # print(cif_output)
    print(predictor_scripts.code)
    # predictor = CifPredictorV2(2, 1, 1)
    # cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
    print(cif_output)
def CifPredictorV2_export_test():
    x = torch.rand([2, 21, 2])
    x_len = torch.IntTensor([6, 21])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]
    # predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
    # cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
    predictor = CifPredictorV2(2, 1, 1)
    predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
    predictor_trace.save('test_trace.pt')
    loaded = torch.jit.load('test_trace.pt')
    x = torch.rand([3, 30, 2])
    x_len = torch.IntTensor([6, 20, 30])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]
    cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
    print(cif_output)
    # print(predictor_trace.code)
    # predictor = CifPredictorV2(2, 1, 1)
    # cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
    # print(cif_output)
if __name__ == '__main__':
    # CifPredictorV2_test()
    CifPredictorV2_export_test()
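The removed test helpers exercised `CifPredictorV2`, which is built on continuous integrate-and-fire (CIF). The NumPy sketch below shows the generic mechanism only, not FunASR's `cif` function, whose tail appears above. Per-frame weights `alphas` accumulate until the running sum reaches a threshold of 1.0. At that point a token embedding fires as the weighted sum of the frames integrated so far, and the leftover weight carries into the next token. `cif_fire` is a hypothetical name, and the sketch assumes each alpha is at most the threshold.

```python
import numpy as np


def cif_fire(hidden: np.ndarray, alphas: np.ndarray, threshold: float = 1.0) -> np.ndarray:
    """hidden: (T, D) frame features; alphas: (T,) weights, each <= threshold.
    Returns (N, D) fired token embeddings; an incomplete trailing token is dropped."""
    fired = []
    acc_weight = 0.0
    acc_state = np.zeros(hidden.shape[1])
    for h, a in zip(hidden, alphas):
        if acc_weight + a < threshold:
            # Not enough weight yet: keep integrating this frame.
            acc_weight += a
            acc_state += a * h
        else:
            # Split the frame: part completes the current token, the rest starts the next one.
            used = threshold - acc_weight
            fired.append(acc_state + used * h)
            acc_weight = a - used
            acc_state = acc_weight * h
    return np.stack(fired) if fired else np.zeros((0, hidden.shape[1]))


out = cif_fire(np.array([[1.0], [2.0], [3.0], [4.0]]), np.array([0.5, 0.75, 0.25, 0.5]))
print(out[:, 0].tolist())  # [1.5, 3.25]
```

The alphas sum to 2.0, so exactly two tokens fire. This mirrors how the CIF predictor's token count follows the total predicted weight.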
funasr/models/e2e_asr_mfcca.py
New file
@@ -0,0 +1,322 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
import pdb
import random
import math
class MFCCA(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        #pdb.set_trace()
        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
            rate_num = random.random()
            #rate_num = 0.1
            if(rate_num<=self.mask_ratio):
                retain_channel = math.ceil(random.random() *8)
                if(retain_channel>1):
                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
                else:
                    speech = speech[:,:,torch.randperm(8)[0]]
        #pdb.set_trace()
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        #pdb.set_trace()
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        if(encoder_out.dim()==4):
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        return encoder_out, encoder_out_lens
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        if(encoder_out.dim()==4):
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError
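`MFCCA.forward` applies a channel augmentation when the input is 3-D with 8 channels and `mask_ratio` is non-zero. With probability `mask_ratio`, it keeps a random subset of the channels. The NumPy sketch below reproduces that selection logic under two assumptions of its own: the channel count is read from the input rather than hard-coded to 8, and an explicit `rng` makes the result reproducible. `random_channel_subset` is a hypothetical name.

```python
import math

import numpy as np


def random_channel_subset(speech: np.ndarray, mask_ratio: float,
                          rng: np.random.Generator) -> np.ndarray:
    """speech: (batch, time, channels). Mirrors the channel masking in MFCCA.forward."""
    num_channels = speech.shape[2]
    if mask_ratio == 0 or rng.random() > mask_ratio:
        return speech  # augmentation skipped for this batch
    # Number of channels to keep; ceil maps (0, 1) onto 1..num_channels.
    retain = math.ceil(rng.random() * num_channels)
    perm = rng.permutation(num_channels)
    if retain > 1:
        # Sorting keeps the kept channels in their original microphone order.
        return speech[:, :, np.sort(perm[:retain])]
    # A single kept channel drops the channel axis, as in the original code.
    return speech[:, :, perm[0]]


rng = np.random.default_rng(0)
x = np.zeros((2, 10, 8))
out = random_channel_subset(x, 1.0, rng)
print(out.shape)  # (2, 10, k) with 2 <= k <= 8, or (2, 10) when one channel is kept
```

Because the single-channel case returns a 2-D tensor, downstream code must accept both layouts. This matches the `channel_size` handling in `_extract_feats` and `encode`.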
funasr/models/e2e_vad.py
@@ -5,7 +5,6 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
# from checkpoint import load_checkpoint
class VadStateMachine(Enum):
@@ -136,7 +135,7 @@
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0 for i in range(0, self.win_size_frame)]  # initialize the window
        self.win_state = [0] * self.win_size_frame  # initialize the window
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
@@ -151,7 +150,7 @@
    def Reset(self) -> None:
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0 for i in range(0, self.win_size_frame)]
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
@@ -192,8 +191,8 @@
        return int(self.frame_size_ms)
class E2EVadModel(torch.nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
class E2EVadModel(nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -212,13 +211,13 @@
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.is_callback_with_sign = False
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -226,10 +225,13 @@
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.streaming = streaming
        self.ResetDetection()
    def AllResetDetection(self):
        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
        self.is_final_send = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
@@ -240,13 +242,13 @@
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.is_callback_with_sign = False
        self.sil_frame = 0
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
        self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -254,6 +256,7 @@
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
@@ -271,26 +274,32 @@
    def ComputeDecibel(self) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        self.data_buf = self.waveform[0]  # points to self.waveform[0]
        if self.data_buf_all is None:
            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
            self.data_buf = self.data_buf_all
        else:
            self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            self.decibel.append(
                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
    def ComputeScores(self, feats: torch.Tensor, feats_lengths: int) -> None:
        self.scores = self.encoder(feats)  # return B * T * D
        self.frm_cnt = feats_lengths # frame
        # return self.scores
    def ComputeScores(self, feats: torch.Tensor) -> None:
        scores = self.encoder(feats)  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
        if self.scores is None:
            self.scores = scores  # the first calculation
        else:
            self.scores = torch.cat((self.scores, scores), dim=1)
    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
        while self.data_buf_start_frame < frame_idx:
            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                self.data_buf_start_frame += 1
                self.data_buf = self.waveform[0][self.data_buf_start_frame * int(
                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
                # for i in range(0, int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)):
                #     self.data_buf.popleft()
                # self.data_buf_start_frame += 1
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
@@ -301,8 +310,9 @@
                               self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            # expected_sample_number = max(expected_sample_number, len(self.data_buf))
            pass
            expected_sample_number = max(expected_sample_number, len(self.data_buf))
        if len(self.data_buf) < expected_sample_number:
            print('error in calling pop data_buf\n')
        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
@@ -312,15 +322,18 @@
            self.output_data_buf[-1].doa = 0
        cur_seg = self.output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning')
            print('warning\n')
        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is not operated on at the moment
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        # if data_to_pop > len(self.data_buf_)
        #   pass
        if data_to_pop > len(self.data_buf):
            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
            data_to_pop = len(self.data_buf)
            expected_sample_number = len(self.data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
@@ -329,7 +342,7 @@
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print('warning')
            print('Something wrong with the VAD algorithm\n')
        self.data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
@@ -346,14 +359,13 @@
    def OnVoiceDetected(self, valid_frame: int) -> None:
        self.latest_confirmed_speech_frame = valid_frame
        if True:  # is_new_api_enable_ = True
            self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if self.confirmed_start_frame != -1:
            print('warning')
            print('not reset vad properly\n')
        else:
            self.confirmed_start_frame = start_frame
@@ -366,7 +378,7 @@
        if self.vad_opts.do_end_point_detection:
            pass
        if self.confirmed_end_frame != -1:
            print('warning')
            print('not reset vad properly\n')
        else:
            self.confirmed_end_frame = end_frame
        if not fake_result:
@@ -406,7 +418,6 @@
            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            # total_score = sum(self.scores[0][t][:])
            total_score = 1.0
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
@@ -433,25 +444,59 @@
        return frame_state
    def forward(self, feats: torch.Tensor, feats_lengths: int, waveform: torch.Tensor) -> List[List[List[int]]]:
        self.AllResetDetection()
    def forward(self, feats: torch.Tensor, waveform: torch.Tensor, is_final_send: bool = False) -> List[List[List[int]]]:
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, feats_lengths)
        assert len(self.decibel) == len(self.scores[0])  # ensure the frame counts match
        self.DetectLastFrames()
        self.ComputeScores(feats)
        if not is_final_send:
            self.DetectCommonFrames()
        else:
            if self.streaming:
                self.DetectLastFrames()
            else:
                self.AllResetDetection()
                self.DetectAllFrames()  # offline decode and is_final_send == True
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            for i in range(0, len(self.output_data_buf)):
                segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                segment_batch.append(segment)
            segments.append(segment_batch)
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
                        i].contain_seg_end_point:
                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                        segment_batch.append(segment)
                        self.output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
        return segments
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
        return 0
    def DetectLastFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            if i != 0:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
            else:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
        return 0
    def DetectAllFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
            frame_state = FrameState.kFrameStateInvalid
            for t in range(0, self.frm_cnt):
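`DetectCommonFrames` and `DetectLastFrames` above both walk the most recent `nn_eval_block_size` frames, oldest first. Only `DetectLastFrames` marks the newest frame as final. Below is a minimal plain-Python sketch of that index arithmetic. `frames_to_detect` is a hypothetical helper that returns the (frame index, is-final) pairs the two methods hand to `DetectOneFrame`.

```python
from typing import List, Tuple


def frames_to_detect(frm_cnt: int, nn_eval_block_size: int,
                     last: bool) -> List[Tuple[int, bool]]:
    """Hypothetical helper: (frame_idx, is_final) pairs, oldest frame first.

    Mirrors the loops in DetectCommonFrames (last=False) and
    DetectLastFrames (last=True); GetFrameState is not modeled.
    """
    pairs = []
    for i in range(nn_eval_block_size - 1, -1, -1):
        idx = frm_cnt - 1 - i
        # Only the newest frame of the final block is flagged as final.
        pairs.append((idx, last and i == 0))
    return pairs


print(frames_to_detect(10, 3, last=False))  # [(7, False), (8, False), (9, False)]
print(frames_to_detect(10, 3, last=True))   # [(7, False), (8, False), (9, True)]
```

The sketch leaves out the early return that both methods take once the state machine reaches `kVadInStateEndPointDetected`.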
funasr/models/encoder/encoder_layer_mfcca.py
New file
@@ -0,0 +1,270 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from funasr.modules.layer_norm import LayerNorm
from torch.autograd import Variable
class Encoder_Conformer_Layer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied, i.e. x -> x + att(x).
    """
    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        cca_pos=0,
    ):
        """Construct an Encoder_Conformer_Layer object."""
        super(Encoder_Conformer_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FFN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.cca_pos = cca_pos
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.
        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)
        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if self.cca_pos<2:
            if pos_emb is not None:
                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
            else:
                x_att = self.self_attn(x_q, x, x, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)
        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)
        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)
        if self.conv_module is not None:
            x = self.norm_final(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class EncoderLayer(nn.Module):
    """Encoder layer module: cross-channel attention followed by an Encoder_Conformer_Layer.
    Args:
        size (int): Input dimension.
        self_attn_cros_channel (torch.nn.Module): Attention module applied across all
            channels within a five-frame window (t-2 to t+2).
            `MultiHeadedAttention` instance can be used as the argument.
        self_attn_conformer (torch.nn.Module): Self-attention module instance for the
            inner `Encoder_Conformer_Layer`. `MultiHeadedAttention` or
            `RelPositionMultiHeadedAttention` instance can be used as the argument.
        feed_forward_csa (torch.nn.Module): Feed-forward module instance for the inner layer.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron_csa (torch.nn.Module): Additional (macaron) feed-forward
            module instance for the inner layer, or None.
        conv_module_csa (torch.nn.Module): Convolution module instance for the inner
            layer, or None. `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block
            (passed to the inner layer).
        concat_after (bool): Whether to concat attention layer's input and output
            in the inner layer.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied, i.e. x -> x + att(x).
    """
    def __init__(
        self,
        size,
        self_attn_cros_channel,
        self_attn_conformer,
        feed_forward_csa,
        feed_forward_macaron_csa,
        conv_module_csa,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.encoder_cros_channel_atten = self_attn_cros_channel
        self.encoder_csa = Encoder_Conformer_Layer(
                size,
                self_attn_conformer,
                feed_forward_csa,
                feed_forward_macaron_csa,
                conv_module_csa,
                dropout_rate,
                normalize_before,
                concat_after,
                cca_pos=0)
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x_input, mask, channel_size, cache=None):
        """Compute encoded features.
        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
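            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            channel_size (int): Number of channels per utterance. The #batch axis of
                `x_input` holds batch * channel_size rows, with each utterance's channels adjacent.
            cache (torch.Tensor): Unused; it is not forwarded to the inner layer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            int: channel_size, returned unchanged.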
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        residual = x
        x = self.norm_mha(x)
        t_leng = x.size(1)
        d_dim = x.size(2)
        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new: (B, T, C, D)
        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
        x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
        x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
        x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
        x_new = x_new.reshape(-1,channel_size,d_dim)
        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
        x = residual + self.dropout(x_att)
        if pos_emb is not None:
            x_input =  (x, pos_emb)
        else:
            x_input = x
        x_input, mask = self.encoder_csa(x_input, mask)
        return x_input, mask , channel_size
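In `EncoderLayer.forward` above, every frame of every channel becomes a query. Each query attends over all channels at frames t-2 through t+2, and the window is zero-padded at the utterance edges. The following is a minimal sketch of that key/value layout, under two assumptions: `build_cross_channel_kv` is a hypothetical helper, and the window half-width `context=2` corresponds to the five hard-coded slices in the diff.

```python
import torch
import torch.nn.functional as F


def build_cross_channel_kv(x: torch.Tensor, channel_size: int, context: int = 2):
    """Hypothetical helper reproducing the query/key layout of EncoderLayer.forward.

    x: (batch * channel_size, time, dim), one utterance's channels adjacent.
    Returns:
        query: (batch * time, channel_size, dim)
        kv:    (batch * time, (2 * context + 1) * channel_size, dim)
    """
    _, t_len, d_dim = x.shape
    # (B*C, T, D) -> (B, C, T, D) -> (B, T, C, D)
    x_new = x.reshape(-1, channel_size, t_len, d_dim).transpose(1, 2)
    # Zero-pad the time axis (dim 1) by `context` frames on each side.
    # F.pad lists pads from the last dim backwards: (D, D, C, C, T, T).
    x_pad = F.pad(x_new, (0, 0, 0, 0, context, context))
    width = 2 * context + 1
    # Slice k holds frame t + k - context for every t: (B, T, width, C, D).
    x_k_v = torch.stack([x_pad[:, k:k + t_len] for k in range(width)], dim=2)
    query = x_new.reshape(-1, channel_size, d_dim)
    kv = x_k_v.reshape(-1, width * channel_size, d_dim)
    return query, kv


x = torch.randn(2 * 3, 6, 4)  # batch=2, channels=3, time=6, dim=4
q, kv = build_cross_channel_kv(x, channel_size=3)
print(q.shape, kv.shape)  # torch.Size([12, 3, 4]) torch.Size([12, 15, 4])
```

Building the windows with `F.pad` and `torch.stack` sidesteps the uninitialized `x_new.new(...)` buffer and the deprecated `Variable` wrapper. The resulting tensor has the same layout as `x_k_v` in the diff, so the reshape that follows the attention call works unchanged.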
funasr/models/encoder/fsmn_encoder.py
@@ -1,55 +1,50 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class LinearTransform(nn.Module):
    def __init__(self, input_dim, output_dim, quantize=0):
    def __init__(self, input_dim, output_dim):
        super(LinearTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.quantize = quantize
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, input):
        if self.quantize:
            output = self.quant(input)
        else:
            output = input
        output = self.linear(output)
        if self.quantize:
            output = self.dequant(output)
        output = self.linear(input)
        return output
class AffineTransform(nn.Module):
    def __init__(self, input_dim, output_dim, quantize=0):
    def __init__(self, input_dim, output_dim):
        super(AffineTransform, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.quantize = quantize
        self.linear = nn.Linear(input_dim, output_dim)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, input):
        if self.quantize:
            output = self.quant(input)
        else:
            output = input
        output = self.linear(output)
        if self.quantize:
            output = self.dequant(output)
        output = self.linear(input)
        return output
class RectifiedLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
    def forward(self, input):
        out = self.relu(input)
        return out
class FSMNBlock(nn.Module):
@@ -62,7 +57,6 @@
            rorder=None,
            lstride=1,
            rstride=1,
            quantize=0
    ):
        super(FSMNBlock, self).__init__()
@@ -84,71 +78,75 @@
                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
        else:
            self.conv_right = None
        self.quantize = quantize
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, input):
    def forward(self, input: torch.Tensor, in_cache=None):
        x = torch.unsqueeze(input, 1)
        x_per = x.permute(0, 3, 2, 1)
        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        if self.quantize:
            y_left = self.quant(y_left)
        x_per = x.permute(0, 3, 2, 1)  # B D T C
        if in_cache is None:  # offline
            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
        else:
            y_left = torch.cat((in_cache, x_per), dim=2)
            in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
        y_left = self.conv_left(y_left)
        if self.quantize:
            y_left = self.dequant(y_left)
        out = x_per + y_left
        if self.conv_right is not None:
            # maybe need to check
            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            if self.quantize:
                y_right = self.quant(y_right)
            y_right = self.conv_right(y_right)
            if self.quantize:
                y_right = self.dequant(y_right)
            out += y_right
        out_per = out.permute(0, 3, 2, 1)
        output = out_per.squeeze(1)
        return output
        return output, in_cache
class RectifiedLinear(nn.Module):
class BasicBlock(nn.Sequential):
    def __init__(self,
                 linear_dim: int,
                 proj_dim: int,
                 lorder: int,
                 rorder: int,
                 lstride: int,
                 rstride: int,
                 stack_layer: int
                 ):
        super(BasicBlock, self).__init__()
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.stack_layer = stack_layer
        self.linear = LinearTransform(linear_dim, proj_dim)
        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
        self.affine = AffineTransform(proj_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
    def __init__(self, input_dim, output_dim):
        super(RectifiedLinear, self).__init__()
        self.dim = input_dim
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
    def forward(self, input):
        out = self.relu(input)
        # out = self.dropout(out)
        return out
    def forward(self, input: torch.Tensor, in_cache=None):
        x1 = self.linear(input)  # B T D
        if in_cache is not None:  # Dict[str, tensor.Tensor]
            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
            if cache_layer_name not in in_cache:
                in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
        else:
            x2, _ = self.fsmn_block(x1)
        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4, in_cache
def _build_repeats(
        fsmn_layers: int,
        linear_dim: int,
        proj_dim: int,
        lorder: int,
        rorder: int,
        lstride=1,
        rstride=1,
):
    repeats = [
        nn.Sequential(
            LinearTransform(linear_dim, proj_dim),
            FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1),
            AffineTransform(proj_dim, linear_dim),
            RectifiedLinear(linear_dim, linear_dim))
        for i in range(fsmn_layers)
    ]
class FsmnStack(nn.Sequential):
    def __init__(self, *args):
        super(FsmnStack, self).__init__(*args)
    return nn.Sequential(*repeats)
    def forward(self, input: torch.Tensor, in_cache=None):
        x = input
        for module in self._modules.values():
            x, in_cache = module(x, in_cache)
        return x
'''
@@ -177,6 +175,7 @@
            rstride: int,
            output_affine_dim: int,
            output_dim: int,
            streaming=False
    ):
        super(FSMN, self).__init__()
@@ -185,23 +184,16 @@
        self.fsmn_layers = fsmn_layers
        self.linear_dim = linear_dim
        self.proj_dim = proj_dim
        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride
        self.output_affine_dim = output_affine_dim
        self.output_dim = output_dim
        self.in_cache_original = dict() if streaming else None
        self.in_cache = copy.deepcopy(self.in_cache_original)
        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        self.relu = RectifiedLinear(linear_dim, linear_dim)
        self.fsmn = _build_repeats(fsmn_layers,
                                   linear_dim,
                                   proj_dim,
                                   lorder, rorder,
                                   lstride, rstride)
        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
                                range(fsmn_layers)])
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)
@@ -209,27 +201,29 @@
    def fuse_modules(self):
        pass
    def cache_reset(self):
        self.in_cache = copy.deepcopy(self.in_cache_original)
    def forward(
            self,
            input: torch.Tensor,
            in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
    ) -> torch.Tensor:
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
            in_cache: when in_cache is not None, the forward runs in streaming mode. in_cache is a dict, e.g.
            {'cache_layer_1': torch.Tensor(B, D, T1, 1)}, where T1 = (lorder - 1) * lstride. It is {} for the 1st frame
        """
        x1 = self.in_linear1(input)
        x2 = self.in_linear2(x1)
        x3 = self.relu(x2)
        x4 = self.fsmn(x3)
        x4 = self.fsmn(x3, self.in_cache)  # if self.in_cache is not None, self.fsmn runs in streaming mode and updates the cache dict in place
        x5 = self.out_linear1(x4)
        x6 = self.out_linear2(x5)
        x7 = self.softmax(x6)
        return x7
        # return x6, in_cache
'''
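In the streaming path, `FSMNBlock.forward` saves the last `(lorder - 1) * lstride` input frames as `in_cache`. With that cache, running the causal depthwise convolution chunk by chunk should match the offline result, which pads with zeros on the left. The sketch below checks this on a stripped-down module. `CausalMemory` is a hypothetical name, and it keeps only the left (look-back) branch, because the right (look-ahead) branch in this diff is padded per call and has no cache.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalMemory(nn.Module):
    """Hypothetical stand-in for the left (causal) branch of FSMNBlock."""

    def __init__(self, dim: int, lorder: int, lstride: int = 1):
        super().__init__()
        assert lorder > 1, "the cache slice below assumes a non-empty history"
        self.context = (lorder - 1) * lstride
        self.conv_left = nn.Conv2d(dim, dim, [lorder, 1], dilation=[lstride, 1],
                                   groups=dim, bias=False)

    def forward(self, x: torch.Tensor, in_cache=None):
        # (B, T, D) -> (B, D, T, 1), the layout FSMNBlock convolves over
        x_per = x.unsqueeze(1).permute(0, 3, 2, 1)
        if in_cache is None:  # offline: zero history on the left
            y_left = F.pad(x_per, [0, 0, self.context, 0])
        else:  # streaming: prepend the frames kept from the previous chunk
            y_left = torch.cat((in_cache, x_per), dim=2)
            in_cache = y_left[:, :, -self.context:, :]
        out = x_per + self.conv_left(y_left)
        return out.permute(0, 3, 2, 1).squeeze(1), in_cache


torch.manual_seed(0)
m = CausalMemory(dim=4, lorder=3, lstride=2)
x = torch.randn(2, 10, 4)
offline, _ = m(x)
cache = torch.zeros(2, 4, m.context, 1)  # same zero init as BasicBlock
outs = []
for chunk in torch.split(x, [3, 4, 3], dim=1):
    y, cache = m(chunk, cache)
    outs.append(y)
streamed = torch.cat(outs, dim=1)
print(torch.allclose(offline, streamed, atol=1e-6))
```

`BasicBlock` creates a zero cache when a `cache_layer_{i}` key is missing. Because of that, the first chunk sees the same zero history that the offline `F.pad` path adds.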
funasr/models/encoder/mfcca_encoder.py
New file
@@ -0,0 +1,450 @@
from typing import Optional
from typing import Tuple
import logging
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
    """
    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation
    def forward(self, x):
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)
        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.
    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, an additional linear layer is applied,
            i.e. x -> x + linear(concat(x, att(x))).
            If False, no additional linear layer is applied, i.e. x -> x + att(x).
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More details can be found in
            https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        encoder_selfattn_layer_raw = MultiHeadedAttention
        encoder_selfattn_layer_args_raw = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: int,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch * channel_size, L, input_size),
                with the channels of each utterance stored in consecutive rows.
            ilens (torch.Tensor): Input length (#batch * channel_size).
            channel_size (int): Number of channels per utterance.
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L', output_size) after channel fusion.
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
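        if channel_size < 8:
            # Tile the channels so conv1, which expects exactly 8 input channels,
            # can be applied; inputs with more than 8 channels are not handled here
            repeat_num = math.ceil(8 / channel_size)
            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]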
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
        mask_tmp = masks.size(1)
        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        num_layer = len(self.encoders)
        for idx, encoder in enumerate(self.encoders):
            xs_pad, masks = encoder(xs_pad, masks)
            if idx == num_layer // 2 - 1:
                hidden_feature = xs_pad
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
            hidden_feature = hidden_feature[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
            self.hidden_feature = self.after_norm(hidden_feature)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
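The channel-fusion step in `forward` above reshapes the flattened (B * C, T, D) encoder output back to (B, C, T, D) and, when fewer than 8 channels are present, tiles the channels so that `conv1` (which takes 8 input channels) can be applied. A minimal NumPy sketch of just that reshaping and padding logic, using `np.tile` as a stand-in for `torch.Tensor.repeat`; `pad_channels_to` is a hypothetical name, not part of FunASR:

```python
import math

import numpy as np


def pad_channels_to(xs: np.ndarray, channel_size: int, target: int = 8) -> np.ndarray:
    """Reshape (B * C, T, D) to (B, C, T, D) and tile channels up to `target`.

    Mirrors the `channel_size < 8` branch of the encoder's forward pass;
    inputs with `target` or more channels are only reshaped.
    """
    _, t_len, d_dim = xs.shape
    xs = xs.reshape(-1, channel_size, t_len, d_dim)
    if channel_size < target:
        repeat_num = math.ceil(target / channel_size)
        # np.tile with reps (1, r, 1, 1) matches torch's xs.repeat(1, r, 1, 1)
        xs = np.tile(xs, (1, repeat_num, 1, 1))[:, :target, :, :]
    return xs


# Two utterances recorded with 3 microphones, 5 frames, 4-dim features
feats = np.arange(2 * 3 * 5 * 4, dtype=np.float32).reshape(2 * 3, 5, 4)
padded = pad_channels_to(feats, channel_size=3)
print(padded.shape)  # (2, 8, 5, 4)
```

Channels are repeated in order (c0, c1, c2, c0, c1, c2, c0, c1), so every original channel reaches the fusion convolutions at least once.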
funasr/models/frontend/default.py
@@ -131,3 +131,128 @@
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
class MultiChannelFrontend(AbsFrontend):
    """Multi-channel frontend structure for ASR.
    Stft -> [WPE -> MVDR-Beamformer] -> Power-spec -> Log-Mel-Fbank
    Per-channel features are flattened into the batch dimension.
    """
    def __init__(
            self,
            fs: Union[int, str] = 16000,
            n_fft: int = 512,
            win_length: int = None,
            hop_length: int = 128,
            window: Optional[str] = "hann",
            center: bool = True,
            normalized: bool = False,
            onesided: bool = True,
            n_mels: int = 80,
            fmin: int = None,
            fmax: int = None,
            htk: bool = False,
            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
            apply_stft: bool = True,
            frame_length: int = None,
            frame_shift: int = None,
            lfr_m: int = None,
            lfr_n: int = None,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft
        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None
        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "multichannelfrontend"
    def output_size(self) -> int:
        return self.n_mels
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            if isinstance(input, ComplexTensor):
                input_stft = input
            else:
                input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)
        # 3. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2
        # 4. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, Length, [Channel,] Freq)
        #       -> input_feats: (Batch, Length, [Channel,] Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
        bt = input_feats.size(0)
        if input_feats.dim() == 4:
            channel_size = input_feats.size(2)
            # (B, T, C, D) -> (B, C, T, D) -> (B * C, T, D): the channels of each
            # utterance occupy consecutive rows
            input_feats = input_feats.transpose(1, 2).reshape(bt * channel_size, -1, self.n_mels).contiguous()
            # Repeat each length C times so the lengths follow the same row order
            feats_lens = feats_lens.repeat_interleave(channel_size)
        else:
            channel_size = 1
        return input_feats, feats_lens, channel_size
    def _compute_stft(
            self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[ComplexTensor, torch.Tensor]:
        input_stft, feats_lens = self.stft(input, input_lengths)
        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape
        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
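`MultiChannelFrontend.forward` flattens per-channel features into the batch dimension with `transpose(1, 2).reshape(B * C, ...)`, which places all channels of utterance 0 first, then utterance 1, and so on; the per-row lengths have to follow the same order. A small NumPy sketch of that index layout, where `np.repeat` plays the role of `torch.repeat_interleave` and `np.tile` the role of a 1-D `Tensor.repeat(1, C)`; the toy shapes and values are illustrative only:

```python
import numpy as np

batch, channels, frames, dim = 2, 3, 4, 1
# feats[b, t, c, 0] = 10 * b + c, so each flattened row is identifiable
feats = np.zeros((batch, frames, channels, dim))
for b in range(batch):
    for c in range(channels):
        feats[b, :, c, 0] = 10 * b + c
lens = np.array([4, 2])  # per-utterance frame counts

# (B, T, C, D) -> (B, C, T, D) -> (B * C, T, D), like transpose(1, 2).reshape(...)
flat = feats.transpose(0, 2, 1, 3).reshape(batch * channels, frames, dim)
row_ids = flat[:, 0, 0].astype(int).tolist()

aligned = np.repeat(lens, channels)  # follows the row order
tiled = np.tile(lens, channels)      # interleaves utterances instead
print(row_ids, aligned.tolist(), tiled.tolist())
# [0, 1, 2, 10, 11, 12] [4, 4, 4, 2, 2, 2] [4, 2, 4, 2, 4, 2]
```

With batch size 1, or when all utterances have equal length, both orderings coincide, which is why a mismatch only shows up for padded batches.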
funasr/models/frontend/wav_frontend.py
@@ -171,10 +171,7 @@
                              window_type=self.window,
                              sample_frequency=self.fs)
            # if self.lfr_m != 1 or self.lfr_n != 1:
            #     mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            # if self.cmvn_file is not None:
            #     mat = apply_cmvn(mat, self.cmvn_file)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
funasr/runtime/__init__.py
funasr/runtime/python/__init__.py
funasr/runtime/python/onnxruntime/__init__.py
funasr/runtime/python/onnxruntime/paraformer/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/README.md
New file
@@ -0,0 +1,75 @@
## Using paraformer with ONNXRuntime
<p align="left">
    <a href=""><img src="https://img.shields.io/badge/Python->=3.7,<=3.10-aff.svg"></a>
    <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
</p>
### Introduction
- The model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
### Steps:
1. Clone the repository and enter this directory.
   ```shell
   git clone https://github.com/alibaba/FunASR.git && cd FunASR
   cd funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer
   ```
2. Install the required packages.
   ```bash
   pip install -r requirements.txt
   ```
3. Export the model.
   `Tip`: torch 1.11.0 is required.
   ```shell
   python -m funasr.export.export_model [model_name] [export_dir] [true]
   ```
   `model_name`: the model to export.
   `export_dir`: the directory where the exported ONNX model is written.
   For more details, see the [export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export).
   - `e.g.`, export a model from ModelScope:
      ```shell
      python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
      ```
   - `e.g.`, export a model from a local path (the model file must be named `model.pb`):
      ```shell
      python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
      ```
4. Run the demo.
   - `model_dir`: the model directory, which contains `model.onnx`, `config.yaml`, and `am.mvn`.
   - Input: a WAV file path, a list of paths, or a waveform array (`str, np.ndarray, List[str]`).
   - Output: `List[str]`, the recognition results.
   - Example:
        ```python
        from paraformer_onnx import Paraformer
        model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
        model = Paraformer(model_dir, batch_size=1)
        wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
        result = model(wav_path)
        print(result)
        ```
## Speed
Environment: Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
Test audio: [wav, 5.3 s, averaged over 100 runs](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav)
| Backend |        RTF        |
|:-------:|:-----------------:|
| Pytorch |       0.110       |
|  Onnx   |       0.038       |
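RTF (real-time factor) is the processing time divided by the audio duration, so values below 1 are faster than real time; an RTF of 0.038 on the 5.3 s clip corresponds to roughly 0.2 s of compute per run. A minimal sketch of how such an average could be measured, assuming any recognizer callable such as the `Paraformer` instance from the demo; `measure_rtf` is a hypothetical helper, not part of this package:

```python
import time
from typing import Any, Callable


def measure_rtf(recognize: Callable[[Any], Any], audio: Any,
                audio_seconds: float, runs: int = 100) -> float:
    """Average real-time factor of `recognize(audio)` over `runs` calls."""
    if audio_seconds <= 0 or runs <= 0:
        raise ValueError("audio_seconds and runs must be positive")
    start = time.perf_counter()
    for _ in range(runs):
        recognize(audio)
    elapsed = time.perf_counter() - start
    # RTF = (processing time per run) / (audio duration)
    return elapsed / runs / audio_seconds


# With the demo model (not run here):
# rtf = measure_rtf(model, wav_path, audio_seconds=5.3)
```

Note that this timing includes audio loading and feature extraction when a file path is passed, not only the ONNX session call.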
## Acknowledgements
1. We acknowledge [SWHL](https://github.com/RapidAI/RapidASR) for contributing the ONNX Runtime Python API.
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py
New file
@@ -0,0 +1,3 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/demo.py
New file
@@ -0,0 +1,9 @@
from paraformer_onnx import Paraformer
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)
wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
result = model(wav_path)
print(result)
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
New file
@@ -0,0 +1,144 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
import librosa
import numpy as np
from utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, get_logger,
                    read_yaml)
from utils.postprocess_utils import sentence_postprocess
from utils.frontend import WavFrontend
logging = get_logger()
class Paraformer():
    def __init__(self, model_dir: Union[str, Path]=None,
                 batch_size: int = 1,
                 device_id: Union[str, int]="-1",
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
        model_file = os.path.join(model_dir, 'model.onnx')
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id)
        self.batch_size = batch_size
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        waveform_list = self.load_data(wav_content, fs)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            try:
                am_scores, valid_token_lens = self.infer(feats, feats_len)
            except ONNXRuntimeError:
                #logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                # One empty result per utterance keeps outputs aligned with inputs
                preds = [''] * (end_idx - beg_idx)
            else:
                preds = self.decode(am_scores, valid_token_lens)
            asr_res.extend(preds)
        return asr_res
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform
        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')
    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len
    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)
        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats
    def infer(self, feats: np.ndarray,
              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        am_scores, token_nums = self.ort_infer([feats, feats_len])
        return am_scores, token_nums
    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]
    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> str:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)
        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)
        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()
        # remove the blank (assumed id 0) and eos (id 2) symbols
        token_int = list(filter(lambda x: x not in (0, 2), token_int))
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-1]
        texts = sentence_postprocess(token)
        text = texts[0]
        # text = self.tokenizer.tokens2text(token)
        return text
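`decode_one` performs greedy (argmax) decoding of the Paraformer output: it takes the best token at each position, drops the blank (0) and eos (2) ids, maps ids to tokens, and truncates to the predicted token count. The sos/eos ids it prepends and appends are stripped again immediately, so they are omitted here. A self-contained NumPy sketch of those steps with a toy vocabulary and scores; `greedy_tokens` is a hypothetical helper and skips `sentence_postprocess`:

```python
from typing import List

import numpy as np


def greedy_tokens(am_score: np.ndarray, valid_token_num: int,
                  token_list: List[str]) -> List[str]:
    """Greedy decoding of one (num_positions, vocab_size) score matrix.

    Follows the order of operations in `decode_one`: argmax per position,
    drop blank (id 0) and eos (id 2), map ids to tokens, then keep the first
    `valid_token_num - 1` tokens (the count is assumed to include eos).
    """
    ids = am_score.argmax(axis=-1).tolist()
    ids = [i for i in ids if i not in (0, 2)]
    tokens = [token_list[i] for i in ids]
    return tokens[:valid_token_num - 1]


token_list = ["<blank>", "<s>", "</s>", "你", "好", "<unk>"]
# Four positions; the last two have eos as their best id
scores = np.array([
    [0.1, 0.0, 0.0, 0.8, 0.1, 0.0],
    [0.1, 0.0, 0.0, 0.1, 0.8, 0.0],
    [0.0, 0.0, 0.9, 0.0, 0.1, 0.0],
    [0.0, 0.0, 0.9, 0.0, 0.1, 0.0],
])
print(greedy_tokens(scores, valid_token_num=3, token_list=token_list))  # ['你', '好']
```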
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/requirements.txt
New file
@@ -0,0 +1,6 @@
librosa
numpy
onnxruntime
scipy
typeguard
kaldi-native-fbank
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/__init__.py
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/frontend.py
New file
@@ -0,0 +1,136 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
from typeguard import check_argument_types
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class WavFrontend():
    """Conventional frontend structure for ASR.
    """
    def __init__(
            self,
            cmvn_file: str = None,
            fs: int = 16000,
            window: str = 'hamming',
            n_mels: int = 80,
            frame_length: int = 25,
            frame_shift: int = 10,
            filter_length_min: int = -1,
            filter_length_max: float = -1,
            lfr_m: int = 1,
            lfr_n: int = 1,
            dither: float = 1.0
    ) -> None:
        check_argument_types()
        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
    def fbank(self,
              waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        # Scale float samples in [-1, 1] to the 16-bit integer range
        waveform = waveform * (1 << 15)
        fbank_fn = knf.OnlineFbank(self.opts)
        fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len
    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
        if self.cmvn_file:
            feat = self.apply_cmvn(feat)
        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len
    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        LFR_inputs = []
        T = inputs.shape[0]
        T_lfr = int(np.ceil(T / lfr_n))
        left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
        inputs = np.vstack((left_padding, inputs))
        T = T + (lfr_m - 1) // 2
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append(
                    (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
            else:
                # process last LFR frame
                num_padding = lfr_m - (T - i * lfr_n)
                frame = inputs[i * lfr_n:].reshape(-1)
                for _ in range(num_padding):
                    frame = np.hstack((frame, inputs[-1]))
                LFR_inputs.append(frame)
        LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
        return LFR_outputs
    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data
        """
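        frame, dim = inputs.shape
        # Row 0 holds the <AddShift> values and row 1 the <Rescale> values
        # (see load_cmvn), so this computes (x + shift) * scale per dimension
        means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
        inputs = (inputs + means) * vars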
        return inputs
    def load_cmvn(self,) -> np.ndarray:
        with open(self.cmvn_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        means_list = []
        vars_list = []
        for i in range(len(lines)):
            line_item = lines[i].split()
            if line_item[0] == '<AddShift>':
                line_item = lines[i + 1].split()
                if line_item[0] == '<LearnRateCoef>':
                    add_shift_line = line_item[3:(len(line_item) - 1)]
                    means_list = list(add_shift_line)
                    continue
            elif line_item[0] == '<Rescale>':
                line_item = lines[i + 1].split()
                if line_item[0] == '<LearnRateCoef>':
                    rescale_line = line_item[3:(len(line_item) - 1)]
                    vars_list = list(rescale_line)
                    continue
        means = np.array(means_list).astype(np.float64)
        vars = np.array(vars_list).astype(np.float64)
        cmvn = np.array([means, vars])
        return cmvn
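`apply_lfr` implements low frame rate (LFR) stacking: for every `lfr_n` input frames it emits one output frame made of `lfr_m` consecutive input frames, so T frames of dimension D become ceil(T / lfr_n) frames of dimension lfr_m * D. A vectorized NumPy sketch that follows the same padding rules (left padding with (lfr_m - 1) // 2 copies of the first frame, right padding with the last frame), written for illustration; `lfr_stack` is a hypothetical name, not part of the module:

```python
import numpy as np


def lfr_stack(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
    """Concatenate `lfr_m` consecutive frames every `lfr_n` frames."""
    num_frames, dim = feats.shape
    num_out = int(np.ceil(num_frames / lfr_n))
    # Left-pad with copies of the first frame
    left = np.repeat(feats[:1], (lfr_m - 1) // 2, axis=0)
    padded = np.vstack([left, feats])
    # Right-pad with copies of the last frame so the final window is full
    needed = (num_out - 1) * lfr_n + lfr_m
    if needed > padded.shape[0]:
        right = np.repeat(feats[-1:], needed - padded.shape[0], axis=0)
        padded = np.vstack([padded, right])
    # Output row i is padded[i * lfr_n : i * lfr_n + lfr_m], flattened
    idx = np.arange(num_out)[:, None] * lfr_n + np.arange(lfr_m)[None, :]
    return padded[idx].reshape(num_out, lfr_m * dim).astype(np.float32)


toy = np.arange(5, dtype=np.float32).reshape(5, 1)
print(lfr_stack(toy, lfr_m=3, lfr_n=2).tolist())
# [[0.0, 0.0, 1.0], [1.0, 2.0, 3.0], [3.0, 4.0, 4.0]]
```

For this toy input, the loop in `apply_lfr` produces the same three windows; the gather-index form simply replaces the per-frame Python loop.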
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/postprocess_utils.py
New file
@@ -0,0 +1,240 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
    # CJK Unified Ideographs; ASCII digits 0-9 are also treated as Chinese here
    if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039':
        return True
    return False
def isAllChinese(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        cur = i.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        word_lists.append(cur)
    if len(word_lists) == 0:
        return False
    for ch in word_lists:
        if isChinese(ch) is False:
            return False
    return True
def isAllAlpha(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        cur = i.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        word_lists.append(cur)
    if len(word_lists) == 0:
        return False
    for ch in word_lists:
        if ch.isalpha() is False and ch != "'":
            return False
        elif ch.isalpha() is True and isChinese(ch) is True:
            return False
    return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    ts_lists = []
    ts_nums = []
    ts_index = 0
    for num in range(words_size):
        if num <= last_num:
            continue
        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
            if num + 1 < words_size and words[
                    num + 1] == ' ' and num + 2 < words_size and len(
                        words[num +
                              2]) == 1 and words[num +
                                                 2].encode('utf-8').isalpha():
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr
                while True:
                    num += 1
                    if num < words_size and words[num] == ' ':
                        num += 1
                        if num < words_size and len(
                                words[num]) == 1 and words[num].encode(
                                    'utf-8').isalpha():
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break
    for num in range(words_size):
        if words[num] == ' ':
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue
        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            word_lists.append(words[num].upper())
            num += 1
            while num < words_size:
                if num in abbr_end:
                    word_lists.append(words[num].upper())
                    last_num = num
                    break
                else:
                    if words[num].encode('utf-8').isalpha():
                        word_lists.append(words[num].upper())
                num += 1
            if time_stamp is not None:
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != ' ':
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                begin = end
    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    middle_lists = []
    word_lists = []
    word_item = ''
    ts_lists = []
    # wash words lists
    for i in words:
        word = ''
        if isinstance(i, str):
            word = i
        else:
            word = i.decode('utf-8')
        if word in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(word)
    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
        if time_stamp is not None:
            ts_lists = time_stamp
    # all alpha characters
    elif isAllAlpha(middle_lists):
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
    # mix characters
    else:
        alpha_blank = False
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ''
            if isAllChinese(ch):
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                raise ValueError('invalid character: {}'.format(ch))
    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ' '.join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != ' ':
                real_word_lists.append(ch)
        sentence = ''.join(word_lists).strip()
        return sentence, real_word_lists
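For all-alphabetic output, `sentence_postprocess` treats a token containing `@@` as a subword piece that is glued to the following token, and starts a new word only after a token without `@@`. A minimal sketch of just that merging rule (timestamps, the Chinese and mixed branches, and the abbreviation handling in `abbr_dispose` are left out); `merge_bpe` is a hypothetical name:

```python
from typing import List


def merge_bpe(tokens: List[str]) -> str:
    """Join subword tokens; a token containing '@@' continues into the next one."""
    words: List[str] = []
    current = ""
    for tok in tokens:
        if tok in ("<s>", "</s>", "<unk>"):
            continue  # special symbols are dropped, as in sentence_postprocess
        if "@@" in tok:
            current += tok.replace("@@", "")
        else:
            words.append(current + tok)
            current = ""
    return " ".join(words)


print(merge_bpe(["<s>", "hel@@", "lo", "wor@@", "ld", "</s>"]))  # hello world
```

As in the original loop, a dangling `@@` piece at the very end is never emitted.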
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils/utils.py
New file
@@ -0,0 +1,256 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import functools
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
                         SessionOptions, get_available_providers, get_device)
from typeguard import check_argument_types
import warnings
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class TokenIDConverter():
    def __init__(self, token_list: Union[List, str]):
        check_argument_types()
        # self.token_list = self.load_token(token_path)
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
    # @staticmethod
    # def load_token(file_path: Union[Path, str]) -> List:
    #     if not Path(file_path).exists():
    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
    #
    #     with open(str(file_path), 'rb') as f:
    #         token_list = pickle.load(f)
    #
    #     if len(token_list) != len(set(token_list)):
    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
    #     return token_list
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self,
                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(
                f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        token2id = {v: i for i, v in enumerate(self.token_list)}
        if self.unk_symbol not in token2id:
            raise TokenIDConverterError(
                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
            )
        unk_id = token2id[self.unk_symbol]
        return [token2id.get(i, unk_id) for i in tokens]
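`TokenIDConverter` treats the last vocabulary entry as the unknown symbol and maps any out-of-vocabulary token to its id. The sketch below reproduces that lookup with plain functions so it can be tried without the class; the function names and the toy vocabulary are illustrative assumptions.

```python
from typing import Dict, Iterable, List


def build_token2id(token_list: List[str]) -> Dict[str, int]:
    # Map each token to its index in the vocabulary list.
    return {tok: idx for idx, tok in enumerate(token_list)}


def tokens_to_ids(tokens: Iterable[str], token_list: List[str]) -> List[int]:
    # As in tokens2ids: unknown tokens fall back to the id of the last
    # vocabulary entry, which is treated as the <unk> symbol.
    token2id = build_token2id(token_list)
    unk_id = token2id[token_list[-1]]
    return [token2id.get(tok, unk_id) for tok in tokens]


def ids_to_tokens(ids: Iterable[int], token_list: List[str]) -> List[str]:
    # Inverse direction: plain list indexing, as in ids2tokens.
    return [token_list[i] for i in ids]


vocab = ["<blank>", "<s>", "你", "好", "<unk>"]
ids = tokens_to_ids(["你", "好", "吗"], vocab)
print(ids)                        # [2, 3, 4]
print(ids_to_tokens(ids, vocab))  # ['你', '好', '<unk>']
```

The original rebuilds `token2id` on every `tokens2ids` call; building it once in `__init__` would avoid that cost for large vocabularies.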
class CharTokenizer():
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
    ):
        check_argument_types()
        self.space_symbol = space_symbol
        self.non_linguistic_symbols = self.load_symbols(symbol_value)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
    @staticmethod
    def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
        if value is None:
            return set()
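        # A non-path iterable is taken as the symbol set itself
        # (isinstance() cannot check a subscripted generic such as Iterable[str]).
        if not isinstance(value, (str, Path)):
            return set(value)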
        file_path = Path(value)
        if not file_path.exists():
            logging.warning("%s doesn't exist.", file_path)
            return set()
        with file_path.open("r", encoding="utf-8") as f:
            return set(line.rstrip() for line in f)
    def text2tokens(self, line: Union[str, list]) -> List[str]:
        tokens = []
        while len(line) != 0:
            for w in self.non_linguistic_symbols:
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w):]
                    break
            else:
                t = line[0]
                if t == " ":
                    t = self.space_symbol
                tokens.append(t)
                line = line[1:]
        return tokens
    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}", '
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )
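`text2tokens` splits a line into single characters while keeping multi-character non-linguistic symbols (for example a noise tag) intact and mapping spaces to the space symbol. The function below is a hedged, standalone sketch of that loop. Two details are additions not present in the original: symbols are tried longest-first (the original iterates a `set`, so match order is arbitrary), and empty symbols are skipped because `startswith("")` would otherwise loop forever. `char_tokenize` is a hypothetical name.

```python
from typing import Iterable, List


def char_tokenize(line: str, non_linguistic: Iterable[str] = (),
                  space_symbol: str = "<space>",
                  remove_non_linguistic: bool = False) -> List[str]:
    """Split text into characters, keeping multi-character symbols whole."""
    # Longest symbols first; empty strings are dropped to avoid an endless loop.
    symbols = sorted((s for s in non_linguistic if s), key=len, reverse=True)
    tokens: List[str] = []
    while line:
        for sym in symbols:
            if line.startswith(sym):
                if not remove_non_linguistic:
                    tokens.append(sym)
                line = line[len(sym):]
                break
        else:
            # No symbol matched: emit one character (spaces become space_symbol).
            ch = line[0]
            tokens.append(space_symbol if ch == " " else ch)
            line = line[1:]
    return tokens


print(char_tokenize("hi <noise>ok", ["<noise>"]))
```

With the default settings this yields `['h', 'i', '<space>', '<noise>', 'o', 'k']`.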
class Hypothesis(NamedTuple):
    """Hypothesis data type."""
    yseq: np.ndarray
    score: Union[float, np.ndarray] = 0
    scores: Dict[str, Union[float, np.ndarray]] = dict()
    states: Dict[str, Any] = dict()
    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
class TokenIDConverterError(Exception):
    pass
class ONNXRuntimeError(Exception):
    pass
class OrtInferSession():
    def __init__(self, model_file, device_id=-1):
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        cuda_ep = 'CUDAExecutionProvider'
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = 'CPUExecutionProvider'
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = []
        if device_id != -1 and get_device() == 'GPU' \
                and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))
        self._verify_model(model_file)
        self.session = InferenceSession(model_file,
                                        sess_options=sess_opt,
                                        providers=EP_list)
        if device_id != -1 and cuda_ep not in self.session.get_providers():
            warnings.warn(f'{cuda_ep} is not available in the current environment, so inference automatically falls back to {cpu_ep}.\n'
                          'Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; '
                          'the compatibility table is on the official website: '
                          'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
                          RuntimeWarning)
    def __call__(self,
                 input_content: List[np.ndarray]) -> List[np.ndarray]:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
    def get_input_names(self):
        return [v.name for v in self.session.get_inputs()]
    def get_output_names(self):
        return [v.name for v in self.session.get_outputs()]
    def get_character_list(self, key: str = 'character'):
        # Read the metadata here too, so this works without a prior have_key() call.
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return self.meta_dict[key].splitlines()
    def have_key(self, key: str = 'character') -> bool:
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in self.meta_dict
    @staticmethod
    def _verify_model(model_path):
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f'{model_path} does not exist.')
        if not model_path.is_file():
            raise FileExistsError(f'{model_path} is not a file.')
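`OrtInferSession` builds its provider list by putting CUDA first only when a GPU id is requested, the onnxruntime build reports a GPU device, and the CUDA provider is available, then always appends the CPU provider as a fallback. The pure-Python sketch below isolates that decision so it can be tested without a GPU. It takes the results of `onnxruntime.get_device()` and `onnxruntime.get_available_providers()` as arguments, and the CUDA options are trimmed to `device_id`. `select_providers` is a hypothetical helper, not part of onnxruntime or FunASR.

```python
from typing import Any, Dict, List, Tuple

CUDA_EP = "CUDAExecutionProvider"
CPU_EP = "CPUExecutionProvider"


def select_providers(device_id: int, device: str,
                     available: List[str]) -> List[Tuple[str, Dict[str, Any]]]:
    """Order execution providers: CUDA first when usable, CPU always last."""
    providers: List[Tuple[str, Dict[str, Any]]] = []
    if device_id != -1 and device == "GPU" and CUDA_EP in available:
        providers.append((CUDA_EP, {"device_id": device_id}))
    # CPU is always appended, so the session can still be created without CUDA.
    providers.append((CPU_EP, {"arena_extend_strategy": "kSameAsRequested"}))
    return providers


print([name for name, _ in select_providers(0, "GPU", [CUDA_EP, CPU_EP])])
print([name for name, _ in select_providers(-1, "GPU", [CUDA_EP, CPU_EP])])
```

In real use the list would be passed as `providers=` to `InferenceSession`, which accepts `(name, options)` tuples; comparing it with `session.get_providers()` afterwards is how the class detects a silent fallback to CPU.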
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f'{yaml_path} does not exist.')
    with open(str(yaml_path), 'rb') as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
@functools.lru_cache()
def get_logger(name='rapid_paraformer'):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method initializes it by
    adding a StreamHandler; otherwise the already initialized logger (or a
    logger whose name starts with an initialized name) is returned directly.
    Args:
        name (str): Logger name.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger
    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    logger_initialized[name] = True
    logger.propagate = False
    return logger
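`get_logger` attaches a handler only on the first request for a name and returns later requests (and names under an initialized prefix) unchanged, which prevents duplicated log lines. A standard-library-only sketch of the same init-once idea follows. It differs from the original in one labeled way: it matches children on a `"."` boundary (`"pkg.sub"` but not `"pkg2"`), whereas the original uses a bare `startswith`. `get_once_logger` and `_initialized` are hypothetical names.

```python
import logging

# Names that already received a handler (hypothetical module-level registry).
_initialized = set()


def get_once_logger(name: str) -> logging.Logger:
    """Attach a single StreamHandler the first time a logger name is requested."""
    logger = logging.getLogger(name)
    # Already initialized, or a dotted child of an initialized logger.
    if any(name == n or name.startswith(n + ".") for n in _initialized):
        return logger
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S"))
    logger.addHandler(handler)
    # Do not forward records to the root logger, which may have its own handler.
    logger.propagate = False
    _initialized.add(name)
    return logger
```

Calling it twice with the same name returns the same object with exactly one handler; the original also wraps the function in `functools.lru_cache`, so repeated calls skip the body entirely.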
funasr/runtime/python/torchscripts/__init__.py
funasr/runtime/python/torchscripts/paraformer/__init__.py
funasr/tasks/abs_task.py
@@ -71,7 +71,7 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_int
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
try:
@@ -1153,6 +1153,14 @@
                if args.batch_bins is not None:
                    args.batch_bins = args.batch_bins * args.ngpu
        # filter out samples whose IDs do not match between wav.scp and text
        if (args.train_shape_file is None and args.dataset_type == "small") or (args.train_data_file is None and args.dataset_type == "large"):
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
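The new block filters `wav.scp`/`text` for "small" datasets when no shape file is supplied and for "large" datasets when no data-list file is supplied; under `simple_ddp` only rank 0 rewrites the files and the other ranks wait at `dist.barrier()`. Because `and` binds tighter than `or` in Python, the condition groups as two independent clauses. A hypothetical predicate that spells out that grouping:

```python
from typing import Optional


def needs_wav_text_filter(train_shape_file: Optional[str],
                          train_data_file: Optional[str],
                          dataset_type: str) -> bool:
    """Return True when the wav.scp/text consistency filter should run."""
    # Explicit parentheses match Python's default and-before-or precedence.
    return ((train_shape_file is None and dataset_type == "small")
            or (train_data_file is None and dataset_type == "large"))
```

For example, a "small" dataset with no shape file is filtered, while a "large" dataset that already has a data-list file is not.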
funasr/tasks/asr.py
@@ -40,6 +40,7 @@
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -47,8 +48,10 @@
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
@@ -86,6 +89,7 @@
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
@@ -119,6 +123,7 @@
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
    ),
    type_check=AbsESPnetModel,
    default="asr",
@@ -142,6 +147,7 @@
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
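The hunks above register MFCCA components by adding entries such as `multichannelfrontend=MultiChannelFrontend`, `mfcca=MFCCA` and `mfcca_enc=MFCCAEncoder` to name-to-class tables, which `build_model` later resolves with `get_class(...)`. The toy registry below only illustrates that lookup pattern under stated assumptions. It is not FunASR's `ClassChoices`, whose exact error behavior may differ, and the `*Stub` classes stand in for the real models.

```python
from typing import Dict, Optional


class SimpleChoices:
    """Toy name-to-class registry, loosely modeled on how the tables are used."""

    def __init__(self, name: str, classes: Dict[str, type],
                 type_check: Optional[type] = None, default: Optional[str] = None):
        self.name = name
        self.classes = {k.lower(): v for k, v in classes.items()}
        self.default = default
        if type_check is not None:
            for key, cls in self.classes.items():
                # Reject registrations that do not derive from the base class.
                if not issubclass(cls, type_check):
                    raise TypeError(f"{key} is not a subclass of {type_check.__name__}")

    def get_class(self, name: Optional[str]) -> Optional[type]:
        # None means "component disabled" (e.g. no specaug).
        if name is None:
            return None
        if name.lower() not in self.classes:
            raise ValueError(f"--{self.name} must be one of {list(self.classes)}: got {name}")
        return self.classes[name.lower()]


class AbsEncoderStub: ...
class RNNEncoderStub(AbsEncoderStub): ...
class MFCCAEncoderStub(AbsEncoderStub): ...


encoder_choices = SimpleChoices(
    "encoder", dict(rnn=RNNEncoderStub, mfcca_enc=MFCCAEncoderStub),
    type_check=AbsEncoderStub, default="rnn")
print(encoder_choices.get_class("mfcca_enc").__name__)
```

Registering a new component therefore only needs one dictionary entry plus a config key such as `encoder: mfcca_enc`.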
@@ -1106,3 +1112,135 @@
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1
    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]
    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None
        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None
        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None
        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )
        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )
        # 9. Select model class
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        # 10. Build model
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )
        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
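The first step of `ASRTaskMFCCA.build_model` normalizes `args.token_list`. A path is read as one token per line, a tuple or list is copied, and anything else is rejected; the resulting length becomes `vocab_size` for both the decoder and the CTC output layer. A standalone sketch of that step, with `load_token_list` as a hypothetical name:

```python
from typing import List, Sequence, Union


def load_token_list(token_list: Union[str, Sequence[str]]) -> List[str]:
    """Return the vocabulary as a list, reading one token per line from a path."""
    if isinstance(token_list, str):
        with open(token_list, encoding="utf-8") as f:
            # rstrip() drops the newline (and any trailing whitespace) per line.
            return [line.rstrip() for line in f]
    if isinstance(token_list, (tuple, list)):
        return list(token_list)
    raise RuntimeError("token_list must be str or list")


vocab = load_token_list(("<blank>", "<s>", "</s>", "<unk>"))
print(len(vocab))  # 4, i.e. vocab_size
```

Writing the loaded list back into `args.token_list`, as the original does, keeps the saved config usable without the token file.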
funasr/tasks/vad.py
@@ -235,7 +235,7 @@
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        #if args.use_preprocessor:
        # if args.use_preprocessor:
        #    retval = CommonPreprocessor(
        #        train=train,
        #        # NOTE(kamo): Check attribute existence for backward compatibility
@@ -254,7 +254,7 @@
        #        if hasattr(args, "rir_scp")
        #        else None,
        #    )
        #else:
        # else:
        #    retval = None
        retval = None
        assert check_return_type(retval)
@@ -291,7 +291,8 @@
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("e2evad")
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf,
                            streaming=args.encoder_conf.get('streaming', False))
        return model
funasr/utils/timestamp_tools.py
@@ -4,88 +4,6 @@
import numpy as np
from typing import Any, List, Tuple, Union
def cut_interval(alphas: torch.Tensor, start: int, end: int, tail: bool):
    if not tail:
        if end == start + 1:
            cut = (end + start) / 2.0
        else:
            alpha = alphas[start+1: end].tolist()
            reverse_steps = 1
            for reverse_alpha in alpha[::-1]:
                if reverse_alpha > 0.35:
                    reverse_steps += 1
                else:
                    break
            cut = end - reverse_steps
    else:
        if end != len(alphas) - 1:
            cut = end + 1
        else:
            cut = start + 1
    return float(cut)
def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text: List[str], begin: int = 0, end: int = None):
    time_stamp_list = []
    alphas = alphas[0]
    text = copy.deepcopy(raw_text)
    if end is None:
        time = speech_lengths * 60 / 1000
        sacle_rate = (time / speech_lengths[0]).tolist()
    else:
        time = (end - begin) / 1000
        sacle_rate = (time / speech_lengths[0]).tolist()
    predictor = (alphas > 0.5).int()
    fire_places = torch.nonzero(predictor == 1).squeeze(1).tolist()
    cuts = []
    npeak = int(predictor.sum())
    nchar = len(raw_text)
    if npeak - 1 == nchar:
        fire_places = torch.where((alphas > 0.5) == 1)[0].tolist()
        for i in range(len(fire_places)):
            if fire_places[i] < len(alphas) - 1:
                if 0.05 < alphas[fire_places[i]+1] < 0.5:
                    fire_places[i] += 1
    elif npeak < nchar:
        lost_num = nchar - npeak
        lost_fire = speech_lengths[0].tolist() - fire_places[-1]
        interval_distance = lost_fire // (lost_num + 1)
        for i in range(1, lost_num + 1):
            fire_places.append(fire_places[-1] + interval_distance)
    elif npeak - 1 > nchar:
        redundance_num = npeak - 1 - nchar
        for i in range(redundance_num):
            fire_places.pop()
    cuts.append(0)
    start_sil = True
    if start_sil:
        text.insert(0, '<sil>')
    for i in range(len(fire_places)-1):
        cuts.append(cut_interval(alphas, fire_places[i], fire_places[i+1], tail=(i==len(fire_places)-2)))
    for i in range(2, len(fire_places)-2):
        if fire_places[i-2] == fire_places[i-1] - 1 and fire_places[i-1] != fire_places[i] - 1:
            cuts[i-1] += 1
    if cuts[-1] != len(alphas) - 1:
        text.append('<sil>')
        cuts.append(speech_lengths[0].tolist())
    cuts.insert(-1, (cuts[-1] + cuts[-2]) * 0.5)
    sec_fire_places = np.array(cuts) * sacle_rate
    for i in range(1, len(sec_fire_places) - 1):
        start, end = sec_fire_places[i], sec_fire_places[i+1]
        if i == len(sec_fire_places) - 2:
            end = time
        time_stamp_list.append([int(round(start, 2) * 1000) + begin, int(round(end, 2) * 1000) + begin])
        text = text[1:]
    if npeak - 1 == nchar or npeak > nchar:
        return time_stamp_list[:-1]
    else:
        return time_stamp_list
def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
    START_END_THRESHOLD = 5
    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
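Reading `TIME_RATE` under the assumption, suggested by the `lfr6` naming, that `10.0` is a 10 ms feature frame shift and `6` is the low-frame-rate factor: one encoder frame covers 60 ms, and after the 3x upsampling one CIF step covers 60 / 3 = 20 ms, i.e. `TIME_RATE` is 0.02 s. The sketch below applies that rate to convert peak positions into token spans. It is deliberately simplified: token i spans from the previous peak to its own peak, and the thresholds of the real `time_stamp_lfr6_pl` (such as `START_END_THRESHOLD`) are not reproduced. `peaks_to_seconds` is a hypothetical name.

```python
from typing import List, Tuple

# 10 ms frame shift x 6 (LFR) = 60 ms per encoder frame; CIF weights are
# upsampled 3x, so one upsampled step covers 60 / 3 = 20 ms = 0.02 s.
TIME_RATE = 10.0 * 6 / 1000 / 3


def peaks_to_seconds(peak_indices: List[int]) -> List[Tuple[float, float]]:
    """Turn consecutive CIF peak positions into (start, end) spans in seconds."""
    spans: List[Tuple[float, float]] = []
    prev = 0
    for idx in peak_indices:
        # Round to milliseconds to hide floating-point noise.
        spans.append((round(prev * TIME_RATE, 3), round(idx * TIME_RATE, 3)))
        prev = idx
    return spans


print(peaks_to_seconds([10, 25, 40]))
```

Peaks at upsampled steps 10, 25 and 40 map to boundaries at 0.2 s, 0.5 s and 0.8 s.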
funasr/utils/wav_utils.py
@@ -287,3 +287,35 @@
            wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
            text_path = os.path.join(split_dir, str(i + 1), "text")
            f_data.write(wav_path + " " + text_path + "\n")
def filter_wav_text(data_dir, dataset):
    wav_file = os.path.join(data_dir, dataset, "wav.scp")
    text_file = os.path.join(data_dir, dataset, "text")
    with open(wav_file, encoding="utf-8") as f_wav, open(text_file, encoding="utf-8") as f_text:
        wav_lines = f_wav.readlines()
        text_lines = f_text.readlines()
    os.rename(wav_file, "{}.bak".format(wav_file))
    os.rename(text_file, "{}.bak".format(text_file))
    wav_dict = {}
    for line in wav_lines:
        # Split only on the first whitespace so paths containing spaces do not break unpacking.
        parts = line.strip().split(maxsplit=1)
        if len(parts) < 2:
            continue
        sample_name, wav_path = parts
        wav_dict[sample_name] = wav_path
    text_dict = {}
    for line in text_lines:
        parts = line.strip().split(" ", 1)
        if len(parts) < 2:
            continue
        sample_name, txt = parts
        text_dict[sample_name] = txt
    filter_count = 0
    with open(wav_file, "w", encoding="utf-8") as f_wav, open(text_file, "w", encoding="utf-8") as f_text:
        for sample_name, wav_path in wav_dict.items():
            if sample_name in text_dict:
                f_wav.write(sample_name + " " + wav_path + "\n")
                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
            else:
                filter_count += 1
    print("{}/{} samples in {} are filtered out because their IDs in wav.scp have no matching entry in text".format(filter_count, len(wav_lines), dataset))
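`filter_wav_text` keeps only utterance IDs that appear in both `wav.scp` and `text`, after backing the originals up as `.bak`. The sketch below separates the parsing and intersection from the file I/O so that the logic can be checked in memory. Both function names are hypothetical. As in the original, only `wav.scp` entries without a transcript are counted as dropped; transcripts without audio are discarded silently.

```python
from typing import Dict, List, Tuple


def parse_kaldi_map(lines: List[str]) -> Dict[str, str]:
    """Parse 'key value...' lines; the value keeps any internal spaces."""
    table: Dict[str, str] = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if len(parts) == 2:
            table[parts[0]] = parts[1]
    return table


def intersect_wav_text(wav_lines: List[str], text_lines: List[str]
                       ) -> Tuple[Dict[str, str], Dict[str, str], int]:
    """Keep only utterance IDs present in both wav.scp and text."""
    wav = parse_kaldi_map(wav_lines)
    text = parse_kaldi_map(text_lines)
    kept_wav = {k: v for k, v in wav.items() if k in text}
    kept_text = {k: text[k] for k in kept_wav}
    # Number of wav.scp entries that had no transcript.
    dropped = len(wav) - len(kept_wav)
    return kept_wav, kept_text, dropped
```

For example, with `u1`, `u2` and `u3` in `wav.scp` but only `u1` and `u3` in `text`, the result keeps `u1` and `u3` and reports one dropped sample.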