nichongjia-2007
2023-07-20 17e8f5b889be2ad31608b5203dc5fbc5fd5c0f8a
Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
26 files modified
2 files deleted
22 files added
Changed files (3084 changes in total)
.github/workflows/UnitTest.yml 2
.github/workflows/main.yml 1
README.md 123
README_zh.md 30
docs/README.md 4
egs/aishell/branchformer/conf/decode_asr_transformer.yaml 6
egs/aishell/branchformer/conf/train_asr_branchformer.yaml 104
egs/aishell/branchformer/local/aishell_data_prep.sh 66
egs/aishell/branchformer/local/download_and_untar.sh 105
egs/aishell/branchformer/path.sh 5
egs/aishell/branchformer/run.sh 225
egs/aishell/branchformer/utils 1
egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml 6
egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml 101
egs/aishell/e_branchformer/local/aishell_data_prep.sh 66
egs/aishell/e_branchformer/local/download_and_untar.sh 105
egs/aishell/e_branchformer/path.sh 5
egs/aishell/e_branchformer/run.sh 225
egs/aishell/e_branchformer/utils 1
egs_modelscope/speaker_diarization/TEMPLATE/README.md 2
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py 4
funasr/bin/asr_inference_launch.py 16
funasr/bin/punc_infer.py 8
funasr/bin/punc_train.py 6
funasr/bin/vad_inference_launch.py 2
funasr/build_utils/build_asr_model.py 4
funasr/datasets/large_datasets/build_dataloader.py 9
funasr/datasets/large_datasets/dataset.py 8
funasr/datasets/large_datasets/utils/tokenize.py 2
funasr/datasets/preprocessor.py 70
funasr/models/encoder/branchformer_encoder.py 545
funasr/models/encoder/e_branchformer_encoder.py 465
funasr/modules/cgmlp.py 124
funasr/modules/fastformer.py 153
funasr/modules/repeat.py 21
funasr/runtime/html5/demo.gif
funasr/runtime/html5/readme.md 2
funasr/runtime/html5/readme_cn.md 135
funasr/runtime/html5/readme_zh.md 93
funasr/runtime/onnxruntime/third_party/download_ffmpeg.sh 5
funasr/runtime/onnxruntime/third_party/download_onnxruntime.sh 5
funasr/runtime/python/websocket/funasr_wss_client.py 1
funasr/runtime/python/websocket/funasr_wss_server.py 6
funasr/runtime/websocket/readme.md 2
funasr/runtime/websocket/readme_zh.md 190
funasr/tasks/asr.py 1
funasr/train/trainer.py 2
funasr/utils/timestamp_tools.py 17
funasr/version.txt 2
setup.py 3
.github/workflows/UnitTest.yml
@@ -6,9 +6,7 @@
        - main
  push:
    branches:
      - dev_wjm
      - dev_jy
      - dev_wjm_infer
jobs:
  build:
.github/workflows/main.yml
@@ -5,7 +5,6 @@
      - main
  push:
    branches:
      - dev_wjm
      - main
      - dev_lyh
README.md
@@ -14,33 +14,60 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new) 
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Usage**](#usage)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md)
| [**Quick Start**](#quick-start)
| [**Runtime**](./funasr/runtime/readme.md)
| [**Model Zoo**](./docs/model_zoo/modelscope_models.md)
| [**Contact**](#contact)
| [**M2MET2.0 Challenge**](https://github.com/alibaba-damo-academy/FunASR#multi-channel-multi-party-meeting-transcription-20-m2met20-challenge)
<a name="whats-new"></a>
## What's new: 
### FunASR runtime-SDK
### FunASR runtime
- 2023.07.03: 
We have released FunASR runtime-SDK-0.1.0; the file transcription service (Mandarin) is now supported ([ZH](funasr/runtime/readme_cn.md)/[EN](funasr/runtime/readme.md))
### Multi-Channel Multi-Party Meeting Transcription 2.0 (M2MeT2.0) Challenge
We are pleased to announce that the M2MeT2.0 challenge has been accepted by the ASRU 2023 challenge special session. Registration is now open. The baseline system is built on FunASR and is provided as a recipe for the AliMeeting corpus. For more details, see the M2MET2.0 guidance ([CN](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html)/[EN](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)).
For challenge details, refer to ([CN](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html)/[EN](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html))
### Release notes
### Speech Recognition
- Academic Models
  - Encoder-Decoder Models (AED): [Transformer](egs/aishell/transformer), [Conformer](egs/aishell/conformer), [Branchformer](egs/aishell/branchformer)
  - Transducer Models (RNNT): [RNNT streaming](egs/aishell/rnnt), [BAT streaming/non-streaming](egs/aishell/bat)
  - Non-autoregressive Model (NAR): [Paraformer](egs/aishell/paraformer)
  - Multi-speaker recognition model: [MFCCA](egs_modelscope/asr/mfcca)
For the release notes, please refer to [news](https://github.com/alibaba-damo-academy/FunASR/releases)
- Industrial-level Models
  - Paraformer Models (Mandarin): [Paraformer-large](egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch), [Paraformer-large-long](egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch), [Paraformer-large streaming](egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online), [Paraformer-large-contextual](egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404)
  - Conformer Models (English): [Conformer]()
  - UniASR streaming offline unifying models: [16k UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary), [16k UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary), [16k UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary), [8k UniASR Mandarin financial domain](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-finance-vocab3445-online/summary), [16k UniASR Mandarin audio-visual domain](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-audio_and_video-vocab3445-online/summary),
  [Southern Fujian Dialect model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary), [French model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary),  [German model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary),  [Vietnamese model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary),  [Persian model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary)
- Speaker Recognition
  - Speaker Verification Model: [xvector](egs_modelscope/speaker_verification)
  - Speaker Diarization Model: [SOND](egs/callhome/diarization/sond)
- Punctuation Restoration
  - Chinese Punctuation Model: [CT-Transformer](egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch), [CT-Transformer streaming](egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727)
- Endpoint Detection
  - [FSMN-VAD](egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common)
- Timestamp Prediction
  - Character-level FA Model: [TP-Aligner](egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline)
<a name="highlights"></a>
## Highlights
- FunASR is a fundamental speech recognition toolkit that offers a variety of features, including speech recognition (ASR), Voice Activity Detection (VAD), Punctuation Restoration, Language Models, Speaker Verification, Speaker diarization and multi-talker ASR.
- We have released a vast collection of academic and industrial pretrained models on the [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition), which can be accessed through our [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md). The representative [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model has achieved SOTA performance in many speech recognition tasks. 
- FunASR offers a user-friendly pipeline for fine-tuning pretrained models from the [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition). Additionally, the optimized dataloader in FunASR enables faster training speeds for large-scale datasets. This feature enhances the efficiency of the speech recognition process for researchers and practitioners.
<a name="Installation"></a>
## Installation
Install from pip
@@ -70,24 +97,60 @@
For more details, please refer to [installation](https://alibaba-damo-academy.github.io/FunASR/en/installation/installation.html)
## Usage
<a name="quick-start"></a>
## Quick Start
You could use FunASR by:
You can use FunASR in the following ways:
- egs
- egs_modelscope
- runtime
- Service Deployment SDK
- Industrial model egs
- Academic model egs
### egs
If you want to train the model from scratch, you could use funasr directly by recipe, as the following:
### Service Deployment SDK
#### Python version Example
Supports real-time streaming speech recognition, uses non-streaming models for error correction, and outputs text with punctuation. Currently, only a single client is supported. For concurrent clients, please refer to the C++ version service deployment SDK below.
##### Server Deployment
```shell
cd egs/aishell/paraformer
. ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
cd funasr/runtime/python/websocket
python funasr_wss_server.py --port 10095
```
More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
### egs_modelscope
If you want to infer or finetune pretraining models from modelscope, you could use funasr by modelscope pipeline, as the following:
##### Client Testing
```shell
python funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
```
For more examples, please refer to [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2).
#### C++ version Example
Currently, offline file transcription service (CPU) is supported, and concurrent requests of hundreds of channels are supported.
##### Server Deployment
You can use the following command to complete the deployment with one click:
```shell
curl -O https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/shell/funasr-runtime-deploy-offline-cpu-zh.sh
sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace ./funasr-runtime-resources
```
##### Client Testing
```shell
python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav"
```
For more examples, please refer to [docs](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/docs/SDK_tutorial_zh.md)
### Industrial Model Egs
If you want to use the pre-trained industrial models in ModelScope for inference or fine-tuning training, you can refer to the following command:
```python
from modelscope.pipelines import pipeline
@@ -102,24 +165,20 @@
print(rec_result)
# {'text': '欢迎大家来体验达摩院推出的语音识别模型'}
```
More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
### runtime
### Academic model egs
An example with websocket:
If you want to train from scratch, usually for academic models, you can start training and inference with the following command:
For the server:
```shell
cd funasr/runtime/python/websocket
python funasr_wss_server.py --port 10095
cd egs/aishell/paraformer
. ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
```
More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
For the client:
```shell
python funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
#python funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
```
More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2)
<a name="contact"></a>
## Contact
If you have any questions about FunASR, please contact us by
README_zh.md
@@ -37,11 +37,33 @@
For details, please refer to the documentation ([click here](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html))
### Academic model updates
### Speech Recognition
### Industrial model updates
- Academic models:
  - Encoder-Decoder models: [Transformer](egs/aishell/transformer), [Conformer](egs/aishell/conformer), [Branchformer](egs/aishell/branchformer)
  - Transducer models: [RNNT (streaming)](egs/aishell/rnnt), [BAT](egs/aishell/bat)
  - Non-autoregressive model: [Paraformer](egs/aishell/paraformer)
  - Multi-speaker recognition model: [MFCCA](egs_modelscope/asr/mfcca)
- Industrial models:
  - Mandarin general-purpose models: [Paraformer-large](egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch), [Paraformer-large long-audio version](egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch), [Paraformer-large streaming version](egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online)
  - Mandarin general-purpose hotword model: [Paraformer-large-contextual](egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404)
  - English general-purpose model: [Conformer]()
  - Unified streaming/offline models: [16k UniASR Minnan (Southern Fujian dialect)](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary), [16k UniASR French](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary), [16k UniASR German](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary), [16k UniASR Vietnamese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary), [16k UniASR Persian](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary),
  [16k UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary), [16k UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary), [16k UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary), [8k UniASR Mandarin financial domain](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-finance-vocab3445-online/summary), [16k UniASR Mandarin audio-visual domain](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-audio_and_video-vocab3445-online/summary)
### Speaker Recognition
  - Speaker verification model: [xvector](egs_modelscope/speaker_verification)
  - Speaker diarization model: [SOND](egs/callhome/diarization/sond)
- 2023/07/06
### Punctuation Restoration
  - Chinese punctuation models: [CT-Transformer](egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch), [CT-Transformer streaming](egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727)
### Endpoint Detection
  - [FSMN-VAD](egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common)
### Timestamp Prediction
  - Character-level model: [TP-Aligner](egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline)
<a name="核心功能"></a>
## Core Features
@@ -180,7 +202,7 @@
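## License
This project is released under the [The MIT License](https://opensource.org/licenses/MIT) open-source license. For the industrial model license, see ([click here](./MODEL_LICENSE))
This project is released under the [The MIT License](https://opensource.org/licenses/MIT) open-source license。 For the industrial model license, see ([click here](./MODEL_LICENSE))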
## Stargazers over time
docs/README.md
@@ -4,9 +4,9 @@
For convenience, we provide users with the ability to generate local HTML manually.
First, you should install the following packages, which are required for building HTML:
```sh
conda activate funasr
pip install requests sphinx nbsphinx sphinx_markdown_tables sphinx_rtd_theme recommonmark
pip3 install -U "funasr[docs]"
```
Then you can generate HTML manually.
egs/aishell/branchformer/conf/decode_asr_transformer.yaml
New file
@@ -0,0 +1,6 @@
beam_size: 10
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.4
lm_weight: 0.0
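The decoding options above weight the CTC, attention, and language-model scores during beam search. As a minimal sketch, assuming the common ESPnet-style linear combination of log-probabilities (the helper `joint_score` is hypothetical and not part of FunASR), a hypothesis score could be computed as:

```python
def joint_score(att_logp: float, ctc_logp: float, lm_logp: float = 0.0,
                ctc_weight: float = 0.4, lm_weight: float = 0.0) -> float:
    """Combine per-hypothesis log-probabilities.

    Assumption: the attention decoder gets weight (1 - ctc_weight), the CTC
    branch gets ctc_weight, and the language model is added with lm_weight.
    """
    return (1.0 - ctc_weight) * att_logp + ctc_weight * ctc_logp + lm_weight * lm_logp
```

With the defaults from this config (`ctc_weight: 0.4`, `lm_weight: 0.0`), the language model term contributes nothing and the attention score carries 60% of the weight.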
egs/aishell/branchformer/conf/train_asr_branchformer.yaml
New file
@@ -0,0 +1,104 @@
# network architecture
# encoder related
encoder: branchformer
encoder_conf:
    output_size: 256
    use_attn: true
    attention_heads: 4
    attention_layer_type: rel_selfattn
    pos_enc_layer_type: rel_pos
    rel_pos_type: latest
    use_cgmlp: true
    cgmlp_linear_units: 2048
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    merge_method: concat
    cgmlp_weight: 0.5               # used only if merge_method is "fixed_ave"
    attn_branch_drop_rate: 0.0      # used only if merge_method is "learned_ave"
    num_blocks: 24
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    stochastic_depth_rate: 0.0
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.
    src_attention_dropout_rate: 0.
# frontend related
frontend: wav_frontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 180
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.001
   weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 35000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10
dataset_conf:
    data_names: speech,text
    data_types: sound,text
    shuffle: True
    shuffle_conf:
        shuffle_size: 2048
        sort_size: 500
    batch_conf:
        batch_type: token
        batch_size: 10000
    num_workers: 8
log_interval: 50
normalize: None
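The `scheduler: warmuplr` entry with `warmup_steps: 35000` and `lr: 0.001` suggests a warm-up learning-rate schedule. The sketch below assumes the Noam-style rule used by ESPnet's `WarmupLR`; whether FunASR's `warmuplr` follows exactly this formula is an assumption, and the function name `warmup_lr` is hypothetical:

```python
def warmup_lr(step: int, base_lr: float = 0.001, warmup_steps: int = 35000) -> float:
    """Learning rate at a given optimizer step (1-based).

    Assumed rule: linear increase up to warmup_steps, then decay
    proportional to the inverse square root of the step count.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

Under this assumption the rate peaks at `base_lr` when `step == warmup_steps`, reaches a quarter of it at step 8750, and falls back to half of it at step 140000.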
egs/aishell/branchformer/local/aishell_data_prep.sh
New file
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0
#. ./path.sh || exit 1;
if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi
aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3
train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp
mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir
# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $0 requires an existing audio directory and transcript file"
  exit 1;
fi
# find wav audio file for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
  echo Warning: expected 141925 data files, found $n
grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
rm -r $tmp_dir
# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo Preparing $dir transcriptions
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
done
mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done
echo "$0: AISHELL data preparation succeeded"
exit 0;
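The transcription loop in `aishell_data_prep.sh` keeps only utterances that have both a wav file and a transcript line. A rough Python equivalent of that filtering step, assuming utterance IDs are the wav file names without extension (the function name `build_wav_scp_and_text` is hypothetical), might look like this:

```python
from pathlib import PurePosixPath


def build_wav_scp_and_text(wav_paths, transcript_lines):
    """Keep only utterances that have both a wav file and a transcript line."""
    # Utterance ID = file name without the ".wav" extension.
    utt_to_wav = {PurePosixPath(p).stem: p for p in wav_paths}

    utt_to_text = {}
    for line in transcript_lines:
        parts = line.strip().split(maxsplit=1)
        if parts and parts[0] in utt_to_wav:
            utt_to_text[parts[0]] = parts[1] if len(parts) > 1 else ""

    # Sorting by ID roughly mirrors the `sort -u` calls in the shell script.
    utts = sorted(utt_to_text)
    wav_scp = {u: utt_to_wav[u] for u in utts}
    text = {u: utt_to_text[u] for u in utts}
    return wav_scp, text
```

Both returned dictionaries share the same keys, which matches the script's goal of producing `wav.scp` and `text` files that list the same utterances.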
egs/aishell/branchformer/local/download_and_untar.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi
if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi
if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi
if [ -f $data/$part/.complete ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi
# sizes of the archive files in bytes.
sizes="15582913665 1246920"
if [ -f $data/$part.tgz ]; then
  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm $data/$part.tgz
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi
if [ ! -f $data/$part.tgz ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
  cd $data || exit 1
  if ! wget --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi
cd $data || exit 1
if ! tar -xvzf $part.tgz; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi
touch $data/$part/.complete
if [ $part == "data_aishell" ]; then
  cd $data/$part/wav || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf $wav && rm $wav
  done
fi
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm $data/$part.tgz
fi
exit 0;
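`download_and_untar.sh` decides whether an existing archive can be reused by comparing its byte size with the two known archive sizes. A minimal Python sketch of the same check (the names `EXPECTED_SIZES` and `archive_looks_complete` are hypothetical) is:

```python
import os

# Byte sizes of data_aishell.tgz and resource_aishell.tgz, as listed in the script.
EXPECTED_SIZES = frozenset({15582913665, 1246920})


def archive_looks_complete(path, expected_sizes=EXPECTED_SIZES):
    """Return True if the file exists and its size matches a known archive size."""
    return os.path.isfile(path) and os.path.getsize(path) in expected_sizes
```

As in the shell script, a size match only suggests that the download finished; it does not verify the archive contents the way a checksum would.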
egs/aishell/branchformer/path.sh
New file
@@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
egs/aishell/branchformer/run.sh
New file
@@ -0,0 +1,225 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5
# feature configuration
feats_dim=80
nj=64
# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Enable strict mode; the script will exit on:
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_asr_branchformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi
token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# LM Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            train.py \
                --task_name asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type $token_type \
                --token_list $token_list \
                --data_dir ${feats_dir}/data \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --data_file_names "wav.scp,text" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --speed_perturb ${speed_perturb} \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "stage 5: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this directory first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
# Prepare files for ModelScope fine-tuning and inference
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "stage 6: ModelScope Preparation"
    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
    vocab_size=$(cat ${token_list} | wc -l)
    python utils/gen_modelscope_configuration.py \
        --am_model_name $inference_asr_model \
        --mode asr \
        --model_name conformer \
        --dataset aishell \
        --output_dir $exp_dir/exp/$model_dir \
        --vocab_size $vocab_size \
        --tag $tag
fi
egs/aishell/branchformer/utils
New file
@@ -0,0 +1 @@
../transformer/utils
egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml
New file
@@ -0,0 +1,6 @@
beam_size: 10
penalty: 0.0
maxlenratio: 0.0
minlenratio: 0.0
ctc_weight: 0.4
lm_weight: 0.0
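Note on the decoding config above: a minimal sketch of how these weights are commonly combined during hybrid CTC/attention beam search. It assumes the ESPnet-style convention in which the attention decoder receives weight `1 - ctc_weight`; the function `joint_score` and its arguments are hypothetical names for illustration, not FunASR API.

```python
def joint_score(att_logp, ctc_logp, lm_logp=0.0, n_tokens=0,
                ctc_weight=0.4, lm_weight=0.0, penalty=0.0):
    """Combine per-hypothesis log-probabilities into one beam-search score.

    Assumption: the decoder weight is (1 - ctc_weight) and `penalty` acts as
    a per-token length bonus, as in ESPnet-style joint decoding.
    """
    return ((1.0 - ctc_weight) * att_logp
            + ctc_weight * ctc_logp
            + lm_weight * lm_logp
            + penalty * n_tokens)


# With the values from the YAML (ctc_weight=0.4, lm_weight=0.0, penalty=0.0)
# the language model and the length bonus contribute nothing.
print(joint_score(att_logp=-1.0, ctc_logp=-2.0))
```

With `lm_weight: 0.0` no language model is needed at decoding time, which matches the empty LM training stage in the recipe.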
egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml
New file
@@ -0,0 +1,101 @@
# network architecture
# encoder related
encoder: e_branchformer
encoder_conf:
    output_size: 256
    attention_heads: 4
    attention_layer_type: rel_selfattn
    pos_enc_layer_type: rel_pos
    rel_pos_type: latest
    cgmlp_linear_units: 1024
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    layer_drop_rate: 0.0
    linear_units: 1024
    positionwise_layer_type: linear
    use_ffn: true
    macaron_ffn: true
    merge_conv_kernel: 31
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.
    src_attention_dropout_rate: 0.
# frontend related
frontend: wav_frontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 180
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.001
   weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 35000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10
dataset_conf:
    data_names: speech,text
    data_types: sound,text
    shuffle: True
    shuffle_conf:
        shuffle_size: 2048
        sort_size: 500
    batch_conf:
        batch_type: token
        batch_size: 10000
    num_workers: 8
log_interval: 50
normalize: None
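Note on the optimizer settings above: a rough sketch of the learning-rate curve implied by `lr: 0.001` and `warmup_steps: 35000`. It assumes the `warmuplr` scheduler follows the Noam-style rule documented for ESPnet's `WarmupLR`; the helper `warmup_lr` is a hypothetical name used only for illustration.

```python
def warmup_lr(step, base_lr=0.001, warmup_steps=35000):
    """Noam-style warmup: linear ramp up to base_lr, then inverse-sqrt decay.

    Assumption: lr = base_lr * warmup_steps**0.5
                     * min(step**-0.5, step * warmup_steps**-1.5)
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)


# The rate rises until step == warmup_steps, where it equals base_lr,
# and decays proportionally to 1/sqrt(step) afterwards.
for step in (1, 17500, 35000, 140000):
    print(step, warmup_lr(step))
```

Under this assumption the peak rate equals the configured `lr`, so changing `warmup_steps` moves the position of the peak without changing its height.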
egs/aishell/e_branchformer/local/aishell_data_prep.sh
New file
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0
#. ./path.sh || exit 1;
if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi
aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3
train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp
mkdir -p $train_dir
mkdir -p $dev_dir
mkdir -p $test_dir
mkdir -p $tmp_dir
# data directory check
if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
  echo "Error: $0 requires an existing audio directory and a transcript directory containing aishell_transcript_v0.8.txt"
  exit 1;
fi
# find wav audio file for train, dev and test resp.
find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
n=`cat $tmp_dir/wav.flist | wc -l`
[ $n -ne 141925 ] && \
  echo "Warning: expected 141925 data files, found $n"
grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
rm -r $tmp_dir
# Transcriptions preparation
for dir in $train_dir $dev_dir $test_dir; do
  echo Preparing $dir transcriptions
  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
  sort -u $dir/transcripts.txt > $dir/text
done
mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done
echo "$0: AISHELL data preparation succeeded"
exit 0;
egs/aishell/e_branchformer/local/download_and_untar.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi
if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi
if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi
if [ -f $data/$part/.complete ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi
# sizes of the archive files in bytes.
sizes="15582913665 1246920"
if [ -f $data/$part.tgz ]; then
  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm $data/$part.tgz
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi
if [ ! -f $data/$part.tgz ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
  cd $data || exit 1
  if ! wget --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi
cd $data || exit 1
if ! tar -xvzf $part.tgz; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi
touch $data/$part/.complete
if [ $part == "data_aishell" ]; then
  cd $data/$part/wav || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf $wav && rm $wav
  done
fi
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm $data/$part.tgz
fi
exit 0;
egs/aishell/e_branchformer/path.sh
New file
@@ -0,0 +1,5 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
egs/aishell/e_branchformer/run.sh
New file
@@ -0,0 +1,225 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5
# feature configuration
feats_dim=80
nj=64
# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
. utils/parse_options.sh || exit 1;
# Set bash to strict mode; it will exit on:
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
asr_config=conf/train_asr_e_branchformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi
token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# LM Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: ASR Training"
    mkdir -p ${exp_dir}/exp/${model_dir}
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            train.py \
                --task_name asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type $token_type \
                --token_list $token_list \
                --data_dir ${feats_dir}/data \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --data_file_names "wav.scp,text" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --speed_perturb ${speed_perturb} \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "stage 5: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} already exists. If you want to decode again, please delete this directory first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
# Prepare files for ModelScope fine-tuning and inference
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "stage 6: ModelScope Preparation"
    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
    vocab_size=$(cat ${token_list} | wc -l)
    python utils/gen_modelscope_configuration.py \
        --am_model_name $inference_asr_model \
        --mode asr \
        --model_name conformer \
        --dataset aishell \
        --output_dir $exp_dir/exp/$model_dir \
        --vocab_size $vocab_size \
        --tag $tag
fi
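Note on the decoding parallelism in run.sh above: the script derives the number of inference jobs from `gpuid_list` and `njob`, then caps it at the number of lines in the key file. The sketch below restates that arithmetic in Python; `inference_jobs` and `num_utts` are hypothetical names, not part of the recipe.

```python
def inference_jobs(gpuid_list, njob, num_utts, gpu_inference=True):
    """Mirror the job-count logic of stage 5.

    GPU decoding: inference_nj = ngpu * njob; CPU decoding: inference_nj = njob.
    The result never exceeds the number of utterances in the key file.
    """
    ngpu = len(gpuid_list.split(","))
    inference_nj = ngpu * njob if gpu_inference else njob
    return min(inference_nj, num_utts)


# Defaults from the script: four GPUs and njob=5 give 20 parallel jobs,
# provided the key file has at least 20 entries.
print(inference_jobs("0,1,2,3", njob=5, num_utts=1000))
```

The cap matters for small test sets: `utils/split_scp.pl` is asked for exactly `_nj` key files, so requesting more jobs than utterances would give some jobs nothing to decode.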
egs/aishell/e_branchformer/utils
New file
@@ -0,0 +1 @@
../transformer/utils
egs_modelscope/speaker_diarization/TEMPLATE/README.md
@@ -3,7 +3,7 @@
> **Note**: 
> The modelscope pipeline supports all the models in 
[model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) 
to inference and finetine. Here we take the model of xvector_sv as example to demonstrate the usage.
to inference and finetune. Here we take the model of xvector_sv as example to demonstrate the usage.
## Inference with pipeline
### Quick start
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py
@@ -8,9 +8,9 @@
import soundfile as sf
param_dict = dict()
param_dict['hotword'] = "信诺"
param_dict['hotword'] = "你的热词"
test_wav = '/Users/shixian/Downloads/tpdebug.wav'
test_wav = 'YOUR_LONG_WAV.wav'
output_dir = './tmp'
os.system("mkdir -p {}".format(output_dir))
funasr/bin/asr_inference_launch.py
@@ -370,7 +370,7 @@
            results = speech2text(**batch)
            if len(results) < 1:
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
                results = [[" ", ["sil"], [2], hyp, 10, 6, []]] * nbest
            time_end = time.time()
            forward_time = time_end - time_beg
            lfr_factor = results[0][-1]
@@ -439,6 +439,7 @@
        logging.info(rtf_avg)
        if writer is not None:
            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
@@ -564,6 +565,7 @@
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
@@ -646,8 +648,7 @@
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][
                    0]) < batch_size_token_ms:
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < speech2vadsegment.vad_model.vad_opts.max_single_segment_time:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
@@ -730,6 +731,7 @@
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
@@ -1327,7 +1329,6 @@
        right_context: Number of frames in right context AFTER subsampling.
        display_partial_hypotheses: Whether to display partial hypotheses.
    """
    # assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
@@ -1339,7 +1340,7 @@
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
@@ -1370,10 +1371,7 @@
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
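Note on the batching change in the hunk at line 646 above: the new condition stops a VAD segment from being merged into the current batch when that segment alone reaches `max_single_segment_time`. The sketch below isolates this rule. `plan_batches` is a hypothetical helper, segments are assumed to be `(start_ms, end_ms)` pairs already sorted as in the original loop, and advancing `beg_idx` to `end_idx` after each batch is an assumption about code outside the hunk.

```python
def plan_batches(segments, batch_size_token_ms, max_single_segment_time):
    """Return (beg_idx, end_idx) index ranges, one per decoding batch."""
    batches = []
    beg_idx = 0
    cum_ms = 0
    n = len(segments)
    for j in range(n):
        cum_ms += segments[j][1] - segments[j][0]
        if j < n - 1:
            next_ms = segments[j + 1][1] - segments[j + 1][0]
            # Keep accumulating only while the next segment still fits the
            # duration budget and is not itself an over-long segment.
            if (cum_ms + next_ms < batch_size_token_ms
                    and next_ms < max_single_segment_time):
                continue
        batches.append((beg_idx, j + 1))
        beg_idx = j + 1  # assumption: the next batch starts where this one ended
        cum_ms = 0
    return batches


segments = [(0, 100), (100, 300), (300, 600), (600, 5600)]
print(plan_batches(segments, batch_size_token_ms=700,
                   max_single_segment_time=1000))
```

In this example the three short segments share one batch and the 5000 ms segment is decoded on its own.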
funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
import os
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i
        self.seg_dict_file = None
        self.seg_jieba = False
        if "seg_jieba" in train_args:
            self.seg_jieba = train_args.seg_jieba
            self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=train_args.token_type,
@@ -50,6 +56,8 @@
            g2p_type=train_args.g2p,
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
            seg_jieba=self.seg_jieba,
            seg_dict_file=self.seg_dict_file
        )
    @torch.no_grad()
funasr/bin/punc_train.py
@@ -44,4 +44,10 @@
    else:
        args.distributed = False
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu * args.num_worker_count
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu * args.num_worker_count
    main(args=args)
funasr/bin/vad_inference_launch.py
@@ -123,7 +123,7 @@
                vad_results.append(item)
                if writer is not None:
                    ibest_writer["text"][keys[i]] = "{}".format(results[i])
        torch.cuda.empty_cache()
        return vad_results
    return _forward
funasr/build_utils/build_asr_model.py
@@ -39,6 +39,8 @@
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -113,6 +115,8 @@
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        branchformer=BranchformerEncoder,
        e_branchformer=EBranchformerEncoder,
        mfcca_enc=MFCCAEncoder,
        chunk_conformer=ConformerChunkEncoder,
    ),
funasr/datasets/large_datasets/build_dataloader.py
@@ -69,12 +69,15 @@
            symbol_table = read_symbol_table(args.token_list)
        if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
            seg_dict = load_seg_dict(args.seg_dict_file)
        if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
            punc_dict = read_symbol_table(args.punc_dict_file)
        if hasattr(args, "punc_list") and args.punc_list is not None:
            punc_dict = read_symbol_table(args.punc_list)
        if hasattr(args, "bpemodel") and args.bpemodel is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
        self.dataset_conf = args.dataset_conf
        self.frontend_conf = args.frontend_conf
        if "frontend_conf" not in args:
            self.frontend_conf =  None
        else:
            self.frontend_conf = args.frontend_conf
        self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None 
        logging.info("dataloader config: {}".format(self.dataset_conf))
        batch_mode = self.dataset_conf.get("batch_mode", "padding")
funasr/datasets/large_datasets/dataset.py
@@ -229,15 +229,15 @@
                           mode=mode, 
                           )
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
    if shuffle:
        buffer_conf = conf.get('shuffle_conf', {})
        buffer_size = buffer_conf['shuffle_size']
funasr/datasets/large_datasets/utils/tokenize.py
@@ -54,9 +54,9 @@
    length = len(text)
    if 'hw_tag' in data:
        pre_index = None
        if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
            # enable preset hotword detect in sampling
            pre_index = None
            for hw in hw_config['pre_hwlist']:
                hw = " ".join(seg_tokenize(hw, seg_dict))
                _find = " ".join(text).find(hw)
funasr/datasets/preprocessor.py
@@ -12,6 +12,7 @@
import scipy.signal
import soundfile
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
@@ -628,6 +629,7 @@
            text_name: str = "text",
            split_text_name: str = "split_text",
            split_with_space: bool = False,
            seg_jieba: bool = False,
            seg_dict_file: str = None,
    ):
        super().__init__(
@@ -655,6 +657,10 @@
        )
        # The data field name for split text.
        self.split_text_name = split_text_name
        self.seg_jieba = seg_jieba
        if self.seg_jieba:
            import jieba
            jieba.load_userdict(seg_dict_file)
    @classmethod
    def split_words(cls, text: str):
@@ -677,12 +683,73 @@
                words.append(current_word)
        return words
    @classmethod
    def isEnglish(cls, text:str):
        if re.search('^[a-zA-Z\']+$', text):
            return True
        else:
            return False
    @classmethod
    def join_chinese_and_english(cls, input_list):
        line = ''
        for token in input_list:
            if cls.isEnglish(token):
                line = line + ' ' + token
            else:
                line = line + token
        line = line.strip()
        return line
    @classmethod
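    def split_words_jieba(cls, text: str):
        # jieba is only imported inside __init__ above, which does not bind a
        # module-level name; import it here so jieba.cut resolves.
        import jieba
        input_list = text.split()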
        token_list_all = []
        language_list = []
        token_list_tmp = []
        language_flag = None
        for token in input_list:
            if cls.isEnglish(token) and language_flag == 'Chinese':
                token_list_all.append(token_list_tmp)
                language_list.append('Chinese')
                token_list_tmp = []
            elif not cls.isEnglish(token) and language_flag == 'English':
                token_list_all.append(token_list_tmp)
                language_list.append('English')
                token_list_tmp = []
            token_list_tmp.append(token)
            if cls.isEnglish(token):
                language_flag = 'English'
            else:
                language_flag = 'Chinese'
        if token_list_tmp:
            token_list_all.append(token_list_tmp)
            language_list.append(language_flag)
        result_list = []
        for token_list_tmp, language_flag in zip(token_list_all, language_list):
            if language_flag == 'English':
                result_list.extend(token_list_tmp)
            else:
                seg_list = jieba.cut(cls.join_chinese_and_english(token_list_tmp), HMM=False)
                result_list.extend(seg_list)
        return result_list
    def __call__(
            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        # Split words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])
            if self.seg_jieba:
  #              jieba.load_userdict(seg_dict_file)
                split_text = self.split_words_jieba(data[self.text_name])
            else:
                split_text = self.split_words(data[self.text_name])
        else:
            split_text = data[self.text_name]
        data[self.text_name] = " ".join(split_text)
@@ -782,7 +849,6 @@
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            #import pdb; pdb.set_trace()
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
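Note on `split_words_jieba` in the preprocessor.py hunks above: before calling jieba, the method groups whitespace-separated tokens into consecutive English and Chinese runs. The sketch below reproduces only that grouping step with `itertools.groupby`, so it runs without jieba installed. `group_language_runs` is a hypothetical helper, and the regular expression is the one used by `isEnglish` in the diff.

```python
import re
from itertools import groupby

_ENGLISH = re.compile(r"^[a-zA-Z']+$")


def is_english(token):
    """True if the token consists only of ASCII letters and apostrophes."""
    return _ENGLISH.match(token) is not None


def group_language_runs(tokens):
    """Split a token list into consecutive (language, tokens) runs."""
    return [("English" if english else "Chinese", list(run))
            for english, run in groupby(tokens, key=is_english)]


print(group_language_runs("你 好 hello world 吗".split()))
```

In the original method English runs are passed through unchanged, while each Chinese run is joined back into a string and segmented with `jieba.cut(..., HMM=False)`.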
funasr/models/encoder/branchformer_encoder.py
New file
@@ -0,0 +1,545 @@
# Copyright 2022 Yifan Peng (Carnegie Mellon University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Branchformer encoder definition.
Reference:
    Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe,
    “Branchformer: Parallel MLP-Attention Architectures to Capture
    Local and Global Context for Speech Recognition and Understanding,”
    in Proceedings of ICML, 2022.
"""
import logging
from typing import List, Optional, Tuple, Union
import numpy
import torch
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.cgmlp import ConvolutionalGatingMLP
from funasr.modules.fastformer import FastSelfAttention
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import (  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from funasr.modules.embedding import (  # noqa: H301
    LegacyRelPositionalEncoding,
    PositionalEncoding,
    RelPositionalEncoding,
    ScaledPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)
class BranchformerEncoderLayer(torch.nn.Module):
    """Branchformer encoder layer module.
    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention, optional
        cgmlp: ConvolutionalGatingMLP, optional
        dropout_rate (float): dropout probability
        merge_method (str): concat, learned_ave, fixed_ave
        cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
            used if merge_method is fixed_ave
        attn_branch_drop_rate (float): probability of dropping the attn branch,
            used if merge_method is learned_ave
        stochastic_depth_rate (float): stochastic depth probability
    """
    def __init__(
        self,
        size: int,
        attn: Optional[torch.nn.Module],
        cgmlp: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_method: str,
        cgmlp_weight: float = 0.5,
        attn_branch_drop_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
    ):
        super().__init__()
        assert (attn is not None) or (
            cgmlp is not None
        ), "At least one branch should be valid"
        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.merge_method = merge_method
        self.cgmlp_weight = cgmlp_weight
        self.attn_branch_drop_rate = attn_branch_drop_rate
        self.stochastic_depth_rate = stochastic_depth_rate
        self.use_two_branches = (attn is not None) and (cgmlp is not None)
        if attn is not None:
            self.norm_mha = LayerNorm(size)  # for the MHA module
        if cgmlp is not None:
            self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = torch.nn.Dropout(dropout_rate)
        if self.use_two_branches:
            if merge_method == "concat":
                self.merge_proj = torch.nn.Linear(size + size, size)
            elif merge_method == "learned_ave":
                # attention-based pooling for two branches
                self.pooling_proj1 = torch.nn.Linear(size, 1)
                self.pooling_proj2 = torch.nn.Linear(size, 1)
                # linear projections for calculating merging weights
                self.weight_proj1 = torch.nn.Linear(size, 1)
                self.weight_proj2 = torch.nn.Linear(size, 1)
                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)
            elif merge_method == "fixed_ave":
                assert (
                    0.0 <= cgmlp_weight <= 1.0
                ), "cgmlp weight should be between 0.0 and 1.0"
                # remove the other branch if only one branch is used
                if cgmlp_weight == 0.0:
                    self.use_two_branches = False
                    self.cgmlp = None
                    self.norm_mlp = None
                elif cgmlp_weight == 1.0:
                    self.use_two_branches = False
                    self.attn = None
                    self.norm_mha = None
                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)
            else:
                raise ValueError(f"unknown merge method: {merge_method}")
        else:
            self.merge_proj = torch.nn.Identity()
    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.
        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask
        # Two branches
        x1 = x
        x2 = x
        # Branch 1: multi-headed attention module
        if self.attn is not None:
            x1 = self.norm_mha(x1)
            if isinstance(self.attn, FastSelfAttention):
                x_att = self.attn(x1, mask)
            else:
                if pos_emb is not None:
                    x_att = self.attn(x1, x1, x1, pos_emb, mask)
                else:
                    x_att = self.attn(x1, x1, x1, mask)
            x1 = self.dropout(x_att)
        # Branch 2: convolutional gating mlp
        if self.cgmlp is not None:
            x2 = self.norm_mlp(x2)
            if pos_emb is not None:
                x2 = (x2, pos_emb)
            x2 = self.cgmlp(x2, mask)
            if isinstance(x2, tuple):
                x2 = x2[0]
            x2 = self.dropout(x2)
        # Merge two branches
        if self.use_two_branches:
            if self.merge_method == "concat":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(torch.cat([x1, x2], dim=-1))
                )
            elif self.merge_method == "learned_ave":
                if (
                    self.training
                    and self.attn_branch_drop_rate > 0
                    and torch.rand(1).item() < self.attn_branch_drop_rate
                ):
                    # Drop the attn branch
                    w1, w2 = 0.0, 1.0
                else:
                    # branch1
                    score1 = (
                        self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5
                    )  # (batch, 1, time)
                    if mask is not None:
                        min_value = torch.finfo(score1.dtype).min
                        score1 = score1.masked_fill(mask.eq(0), min_value)
                        score1 = torch.softmax(score1, dim=-1).masked_fill(
                            mask.eq(0), 0.0
                        )
                    else:
                        score1 = torch.softmax(score1, dim=-1)
                    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)
                    # branch2
                    score2 = (
                        self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5
                    )  # (batch, 1, time)
                    if mask is not None:
                        min_value = torch.finfo(score2.dtype).min
                        score2 = score2.masked_fill(mask.eq(0), min_value)
                        score2 = torch.softmax(score2, dim=-1).masked_fill(
                            mask.eq(0), 0.0
                        )
                    else:
                        score2 = torch.softmax(score2, dim=-1)
                    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)
                    # normalize weights of two branches
                    merge_weights = torch.softmax(
                        torch.cat([weight1, weight2], dim=-1), dim=-1
                    )  # (batch, 2)
                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
                        -1
                    )  # (batch, 2, 1, 1)
                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(w1 * x1 + w2 * x2)
                )
            elif self.merge_method == "fixed_ave":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(
                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
                    )
                )
            else:
                raise RuntimeError(f"unknown merge method: {self.merge_method}")
        else:
            if self.attn is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
            elif self.cgmlp is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
            else:
                # This should not happen
                raise RuntimeError("Both branches are not None, which is unexpected.")
        x = self.norm_final(x)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class BranchformerEncoder(AbsEncoder):
    """Branchformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        super().__init__()
        self._output_size = output_size
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            # f-string avoids a TypeError when input_layer is not a str
            raise ValueError(f"unknown input_layer: {input_layer}")
        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )
        if isinstance(stochastic_depth_rate, float):
            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
        if len(stochastic_depth_rate) != num_blocks:
            raise ValueError(
                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )
        if isinstance(cgmlp_weight, float):
            cgmlp_weight = [cgmlp_weight] * num_blocks
        if len(cgmlp_weight) != num_blocks:
            raise ValueError(
                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
                f"num_blocks ({num_blocks})"
            )
        if isinstance(attn_branch_drop_rate, float):
            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
        if len(attn_branch_drop_rate) != num_blocks:
            raise ValueError(
                f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
                f"should be equal to num_blocks ({num_blocks})"
            )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: BranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args)
                if use_attn
                else None,
                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                dropout_rate,
                merge_method,
                cgmlp_weight[lnum],
                attn_branch_drop_rate[lnum],
                stochastic_depth_rate[lnum],
            ),
        )
        self.after_norm = LayerNorm(output_size)
    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if isinstance(
            self.embed,
            (
                Conv2dSubsampling,
                Conv2dSubsampling2,
                Conv2dSubsampling6,
                Conv2dSubsampling8,
            ),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
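
The stochastic depth rule used in BranchformerEncoderLayer.forward above (skip the layer with probability p during training, otherwise scale the residual branch by 1 / (1 - p)) can be sketched without torch. This is a minimal illustration on plain floats, not part of the patch: the helper names stochastic_depth_residual and expected_training_output, and the rng argument, are hypothetical and exist only for this example.

```python
import random


def stochastic_depth_residual(x, f, p, training, rng=random):
    """Hypothetical scalar version of the layer's skip/scale rule.

    x: input value, f: residual branch, p: stochastic depth rate.
    """
    if not training or p <= 0.0:
        # Inference (or p == 0): plain residual connection x + f(x).
        return x + f(x)
    if rng.random() < p:
        # Layer skipped: the input passes through unchanged.
        return x
    # Layer kept: rescale the branch so the expectation matches inference.
    return x + f(x) / (1.0 - p)


def expected_training_output(x, f, p):
    """Exact expectation over the skip decision for 0 <= p < 1."""
    skipped = x
    kept = x + f(x) / (1.0 - p)
    return p * skipped + (1.0 - p) * kept
```

With the 1 / (1 - p) factor, the expected training-time output equals the inference-time output x + f(x), which is why no rescaling is needed at evaluation.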
funasr/models/encoder/e_branchformer_encoder.py
New file
@@ -0,0 +1,465 @@
# Copyright 2022 Kwangyoun Kim (ASAPP inc.)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""E-Branchformer encoder definition.
Reference:
    Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
    Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
    "E-Branchformer: Branchformer with Enhanced merging
    for speech recognition," in SLT 2022.
"""
import logging
from typing import List, Optional, Tuple
import torch
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.cgmlp import ConvolutionalGatingMLP
from funasr.modules.fastformer import FastSelfAttention
from funasr.modules.nets_utils import get_activation, make_pad_mask
from funasr.modules.attention import (  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from funasr.modules.embedding import (  # noqa: H301
    LegacyRelPositionalEncoding,
    PositionalEncoding,
    RelPositionalEncoding,
    ScaledPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)
class EBranchformerEncoderLayer(torch.nn.Module):
    """E-Branchformer encoder layer module.
    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention
        cgmlp: ConvolutionalGatingMLP
        feed_forward: feed-forward module, optional
        feed_forward_macaron: macaron-style feed-forward module, optional
        dropout_rate (float): dropout probability
        merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
    """
    def __init__(
        self,
        size: int,
        attn: torch.nn.Module,
        cgmlp: torch.nn.Module,
        feed_forward: Optional[torch.nn.Module],
        feed_forward_macaron: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_conv_kernel: int = 3,
    ):
        super().__init__()
        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.ff_scale = 1.0
        if self.feed_forward is not None:
            self.norm_ff = LayerNorm(size)
        if self.feed_forward_macaron is not None:
            self.ff_scale = 0.5
            self.norm_ff_macaron = LayerNorm(size)
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            size + size,
            size + size,
            kernel_size=merge_conv_kernel,
            stride=1,
            padding=(merge_conv_kernel - 1) // 2,
            groups=size + size,
            bias=True,
        )
        self.merge_proj = torch.nn.Linear(size + size, size)
    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.
        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
        # Two branches
        x1 = x
        x2 = x
        # Branch 1: multi-headed attention module
        x1 = self.norm_mha(x1)
        if isinstance(self.attn, FastSelfAttention):
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)
        x1 = self.dropout(x_att)
        # Branch 2: convolutional gating mlp
        x2 = self.norm_mlp(x2)
        if pos_emb is not None:
            x2 = (x2, pos_emb)
        x2 = self.cgmlp(x2, mask)
        if isinstance(x2, tuple):
            x2 = x2[0]
        x2 = self.dropout(x2)
        # Merge two branches
        x_concat = torch.cat([x1, x2], dim=-1)
        x_tmp = x_concat.transpose(1, 2)
        x_tmp = self.depthwise_conv_fusion(x_tmp)
        x_tmp = x_tmp.transpose(1, 2)
        x = x + self.dropout(self.merge_proj(x_concat + x_tmp))
        if self.feed_forward is not None:
            # feed forward module
            residual = x
            x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        x = self.norm_final(x)
        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
class EBranchformerEncoder(AbsEncoder):
    """E-Branchformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        layer_drop_rate: float = 0.0,
        max_pos_emb_len: int = 5000,
        use_ffn: bool = False,
        macaron_ffn: bool = False,
        ffn_activation_type: str = "swish",
        linear_units: int = 2048,
        positionwise_layer_type: str = "linear",
        merge_conv_kernel: int = 3,
        interctc_layer_idx=None,
        interctc_use_conditioning: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            # f-string avoids a TypeError when input_layer is not a str
            raise ValueError(f"unknown input_layer: {input_layer}")
        activation = get_activation(ffn_activation_type)
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type is None:
            logging.warning("no macaron ffn")
        else:
            raise ValueError("Support only linear.")
        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EBranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                cgmlp_layer(*cgmlp_layer_args),
                positionwise_layer(*positionwise_layer_args) if use_ffn else None,
                positionwise_layer(*positionwise_layer_args)
                if use_ffn and macaron_ffn
                else None,
                dropout_rate,
                merge_conv_kernel,
            ),
            layer_drop_rate,
        )
        self.after_norm = LayerNorm(output_size)
        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
    def output_size(self) -> int:
        return self._output_size
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        max_layer: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): Intermediate CTC module.
            max_layer (int): Index of the last encoder layer to run. Only used
                when no InterCTC layers are configured.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            if max_layer is not None and 0 <= max_layer < len(self.encoders):
                for layer_idx, encoder_layer in enumerate(self.encoders):
                    xs_pad, masks = encoder_layer(xs_pad, masks)
                    if layer_idx >= max_layer:
                        break
            else:
                xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        if isinstance(xs_pad, tuple):
                            xs_pad = list(xs_pad)
                            xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
                            xs_pad = tuple(xs_pad)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
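Two details of the layer loop in `forward` are easy to misread: when `max_layer` is set, the loop breaks only after layer `max_layer` has run (so `max_layer + 1` layers are applied), and intermediate outputs are keyed by 1-based depth (`layer_idx + 1`). A minimal sketch of that control flow in plain Python, with integers standing in for tensors; `run_layers` is a hypothetical helper, not part of FunASR, and it ignores masks, conditioning, and layer drop:

```python
from typing import Callable, List, Optional, Sequence, Tuple


def run_layers(
    layers: Sequence[Callable[[int], int]],
    x: int,
    interctc_layer_idx: Sequence[int] = (),
    max_layer: Optional[int] = None,
) -> Tuple[int, List[Tuple[int, int]]]:
    """Mimic the branching of the encoder's layer loop (illustrative only)."""
    intermediate_outs: List[Tuple[int, int]] = []
    if len(interctc_layer_idx) == 0:
        if max_layer is not None and 0 <= max_layer < len(layers):
            # Early exit: the break comes after layer `max_layer` has run.
            for layer_idx, layer in enumerate(layers):
                x = layer(x)
                if layer_idx >= max_layer:
                    break
        else:
            for layer in layers:
                x = layer(x)
    else:
        for layer_idx, layer in enumerate(layers):
            x = layer(x)
            # Taps are keyed by 1-based depth, as in the hunk.
            if layer_idx + 1 in interctc_layer_idx:
                intermediate_outs.append((layer_idx + 1, x))
    return x, intermediate_outs


# Four toy "layers" that each add 1 to their input.
toy_layers = [lambda v: v + 1 for _ in range(4)]
full_out, _ = run_layers(toy_layers, 0)
early_out, _ = run_layers(toy_layers, 0, max_layer=1)
_, taps = run_layers(toy_layers, 0, interctc_layer_idx=[2, 3])
```

Under these assumptions `early_out` reflects two applied layers, and `taps` holds the outputs after the second and third layers.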
funasr/modules/cgmlp.py
New file
@@ -0,0 +1,124 @@
"""MLP with convolutional gating (cgMLP) definition.
References:
    https://openreview.net/forum?id=RA-zVvZLYIy
    https://arxiv.org/abs/2105.08050
"""
import torch
from funasr.modules.nets_utils import get_activation
from funasr.modules.layer_norm import LayerNorm
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU)."""
    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()
        n_channels = size // 2  # split input channels
        self.norm = LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            1,
            (kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None
        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)
        self.dropout = torch.nn.Dropout(dropout_rate)
    def espnet_initialization_fn(self):
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)
    def forward(self, x, gate_add=None):
        """Forward method
        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor): (N, T, D/2)
        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        x_r, x_g = x.chunk(2, dim=-1)
        x_g = self.norm(x_g)  # (N, T, D/2)
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)
        if gate_add is not None:
            x_g = x_g + gate_add
        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out
class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP)."""
    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
    def forward(self, x, mask):
        if isinstance(x, tuple):
            xs_pad, pos_emb = x
        else:
            xs_pad, pos_emb = x, None
        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size
        if pos_emb is not None:
            out = (xs_pad, pos_emb)
        else:
            out = xs_pad
        return out
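The depthwise convolution in `ConvolutionalSpatialGatingUnit` uses stride 1 and padding `(kernel_size - 1) // 2`, which preserves the time dimension only when `kernel_size` is odd; with an even kernel the gate would be one frame shorter than `x_r` and the product `x_r * x_g` would not line up. The sketch below checks this with the standard Conv1d output-length formula and traces the channel widths noted in the `forward` comments. `conv1d_out_len` and `cgmlp_widths` are illustrative helpers, not FunASR functions, and the sizes are arbitrary:

```python
from typing import List


def conv1d_out_len(length: int, kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Output length of a 1-D convolution padded as in the CSGU hunk."""
    padding = (kernel_size - 1) // 2
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def cgmlp_widths(size: int, linear_units: int) -> List[int]:
    """Channel width after each stage: input, channel_proj1, csgu, channel_proj2."""
    return [size, linear_units, linear_units // 2, size]


odd_len = conv1d_out_len(100, kernel_size=31)   # odd kernel keeps the length
even_len = conv1d_out_len(100, kernel_size=4)   # even kernel loses one frame
widths = cgmlp_widths(256, 2048)
```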
funasr/modules/fastformer.py
New file
@@ -0,0 +1,153 @@
"""Fastformer attention definition.
Reference:
    Wu et al., "Fastformer: Additive Attention Can Be All You Need"
    https://arxiv.org/abs/2108.09084
    https://github.com/wuch15/Fastformer
"""
import numpy
import torch
class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""
    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads
        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)
    def espnet_initialization_fn(self):
        self.apply(self.init_weights)
    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.
        Args:
            x: (batch, time, size = n_heads * attn_dim)
        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)
    def forward(self, xs_pad, mask):
        """Forward method.
        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0
        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)
        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0
        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)
        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
        mixed_query_key_layer = (
            mixed_key_layer * pooled_query_repeat
        )  # (batch, time, size)
        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)
        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)
        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )
        return weighted_value
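Both pooling steps in `forward` follow the same masking pattern: padded positions are filled with the most negative finite value before the softmax, then zeroed again afterwards. The second fill matters when every position is padded, because the softmax would otherwise return uniform weights. A plain-Python sketch of the pattern for one head of one utterance; `masked_softmax` and `pool` are hypothetical helpers, and lists stand in for tensors:

```python
import math
import sys
from typing import List


def masked_softmax(scores: List[float], is_padding: List[bool]) -> List[float]:
    """Softmax over non-padded positions; padded positions get weight 0."""
    min_value = -sys.float_info.max  # plays the role of numpy.finfo(dtype).min
    filled = [min_value if pad else s for s, pad in zip(scores, is_padding)]
    peak = max(filled)
    exps = [math.exp(s - peak) for s in filled]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Second mask: zero the padded positions explicitly.
    return [0.0 if pad else w for w, pad in zip(weights, is_padding)]


def pool(weights: List[float], vectors: List[List[float]]) -> List[float]:
    """Weighted sum over time, the list analogue of matmul(weight, layer)."""
    dim = len(vectors[0])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]


weights = masked_softmax([0.0, 0.0, 5.0], [False, False, True])
pooled = pool(weights, [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]])
all_padded = masked_softmax([1.0, 2.0], [True, True])
```

The padded third frame has the largest raw score but contributes nothing to `pooled`.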
funasr/modules/repeat.py
@@ -14,25 +14,38 @@
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""
    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.
        Args:
            layer_drop_rate (float): Probability of dropping out each fn (layer).
        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate
    def forward(self, *args):
        """Repeat."""
        for m in self:
            args = m(*args)
        _probs = torch.empty(len(self)).uniform_()
        for idx, m in enumerate(self):
            if not self.training or (_probs[idx] >= self.layer_drop_rate):
                args = m(*args)
        return args
def repeat(N, fn):
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.
    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module.
        layer_drop_rate (float): Probability of dropping out each fn (layer).
    Returns:
        MultiSequential: Repeated model instance.
    """
    return MultiSequential(*[fn(n) for n in range(N)])
    return MultiSequential(*[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate)
class MultiBlocks(torch.nn.Module):
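The new `forward` draws one uniform sample per layer and, during training, skips a layer when its sample falls below `layer_drop_rate`; outside training every layer runs. A minimal sketch of that rule, using the standard library's `random` in place of `torch.empty(...).uniform_()`; `apply_with_layer_drop` is a hypothetical name and integers stand in for tensors:

```python
import random
from typing import Callable, Optional, Sequence


def apply_with_layer_drop(
    layers: Sequence[Callable[[int], int]],
    x: int,
    layer_drop_rate: float,
    training: bool,
    rng: Optional[random.Random] = None,
) -> int:
    """Apply layers in order, skipping each with probability layer_drop_rate."""
    if rng is None:
        rng = random.Random()
    probs = [rng.random() for _ in layers]  # one sample in [0, 1) per layer
    for prob, layer in zip(probs, layers):
        if not training or prob >= layer_drop_rate:
            x = layer(x)
    return x


toy_layers = [lambda v: v + 1 for _ in range(4)]
eval_out = apply_with_layer_drop(toy_layers, 0, layer_drop_rate=1.0, training=False)
keep_all = apply_with_layer_drop(toy_layers, 0, layer_drop_rate=0.0, training=True)
drop_all = apply_with_layer_drop(toy_layers, 0, layer_drop_rate=1.0, training=True)
```

Because the samples lie in `[0, 1)`, a rate of `0.0` never drops a layer and a rate of `1.0` drops every layer while training.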
funasr/runtime/html5/demo.gif
Binary files differ
funasr/runtime/html5/readme.md
@@ -1,3 +1,5 @@
([简体中文](./readme_zh.md)|English)
# Html5 server for asr service
## Requirement
funasr/runtime/html5/readme_cn.md
File was deleted
funasr/runtime/html5/readme_zh.md
New file
@@ -0,0 +1,93 @@
(简体中文|[English](./readme.md))
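# Html5 client interface for the speech recognition service
The server is deployed over the websocket protocol. The client can be used as an html5 web page and supports both microphone input and file input. There are two ways to access it:
- Option 1:
   Direct connection from the html client: manually download the client ([click here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/html5/static)) to your local machine, open `index.html`, and enter the wss address and port number
- Option 2:
   html5 server: the client is downloaded to the local machine automatically, and access from phones and other devices is supported
## Starting the speech recognition service
Both a python version and a c++ version of the service can be deployed:
- python version
  Deploys the python pipeline directly. It supports the streaming real-time speech recognition model, the offline speech recognition model, and the integrated streaming and offline error-correction model, and outputs punctuated text. A single server supports a single client.
- c++ version
  funasr-runtime-sdk, with one-click deployment. Version 0.1.0 supports offline file transcription. A single server supports hundreds of concurrent client requests.
### Starting the python service
#### Install dependencies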
```shell
pip3 install -U modelscope funasr flask
# Users in mainland China who run into network problems can install with the command below:
# pip3 install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
git clone https://github.com/alibaba/FunASR.git && cd FunASR
```
#### Start the ASR service
#### wss mode
```shell
cd funasr/runtime/python/websocket
python funasr_wss_server.py --port 10095
```
For detailed parameter configuration and explanation, see ([click here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket))
#### html5 service (optional)
To use Option 2 described above, start the html5 service:
```shell
h5Server.py [-h] [--host HOST] [--port PORT] [--certfile CERTFILE] [--keyfile KEYFILE]
```
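An example follows. Pay attention to the ip address: if the service must be reachable from other devices (for example a phone), set the ip address to the real public ip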
```shell
cd funasr/runtime/html5
python h5Server.py --host 0.0.0.0 --port 1337
```
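After startup, open ([https://127.0.0.1:1337/static/index.html](https://127.0.0.1:1337/static/index.html)) in a browser to access the page
### Starting the c++ service
Because the c++ version has many dependencies, docker deployment is recommended; it starts the service with a single command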
```shell
curl -O https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/shell/funasr-runtime-deploy-offline-cpu-zh.sh;
sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace /root/funasr-runtime-resources
```
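For detailed parameter configuration and explanation, see ([click here](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/docs/SDK_tutorial_zh.md))
## Client testing
### Option 1
Direct connection from the html client: manually download the client ([click here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/html5/static)) to your local machine, open `index.html`, and enter the wss address and port number to start using it
### Option 2
html5 server: the client is downloaded to the local machine automatically, and access from phones and other devices is supported. The ip address must match that of the html5 server; on the local machine you can use 127.0.0.1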
```shell
https://127.0.0.1:1337/static/index.html
```
Enter the wss address and port number to start using it
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the html5 demo.
funasr/runtime/onnxruntime/third_party/download_ffmpeg.sh
New file
@@ -0,0 +1,5 @@
wget https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2023-07-09-12-50/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
# Users in mainland China can use the following instead
# wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
# tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
funasr/runtime/onnxruntime/third_party/download_onnxruntime.sh
New file
@@ -0,0 +1,5 @@
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
funasr/runtime/python/websocket/funasr_wss_client.py
@@ -204,6 +204,7 @@
        
            meg = await websocket.recv()
            meg = json.loads(meg)
            print(meg)
            wav_name = meg.get("wav_name", "demo")
            text = meg["text"]
funasr/runtime/python/websocket/funasr_wss_server.py
@@ -240,7 +240,8 @@
                                                         param_dict=websocket.param_dict_punc)
                    # print("offline", rec_result)
                if 'text' in rec_result:
                    message = json.dumps({"mode": websocket.mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                    mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
                    message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                    await websocket.send(message)
@@ -256,7 +257,8 @@
        if "text" in rec_result:
            if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
                # print("online", rec_result)
                message = json.dumps({"mode": websocket.mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
                message = json.dumps({"mode": mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                await websocket.send(message)
if len(args.certfile)>0:
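After this change a session whose mode contains `2pass` emits messages tagged either `2pass-online` or `2pass-offline`, so a client can tell the two result streams apart by the `mode` field. A sketch of how a client might route them, using only the fields visible in the hunk (`mode`, `text`, `wav_name`). The `Transcript` class, and its policy of treating online results as provisional text that the next offline result replaces, are assumptions for illustration and not the behaviour of `funasr_wss_client.py`:

```python
import json
from dataclasses import dataclass, field
from typing import List


@dataclass
class Transcript:
    finalized: List[str] = field(default_factory=list)
    partial: str = ""

    def feed(self, raw: str) -> None:
        """Consume one JSON message as sent by the server."""
        msg = json.loads(raw)
        mode = msg.get("mode", "")
        text = msg["text"]
        if mode == "2pass-online":
            # Assumed: online results are provisional and accumulate.
            self.partial += text
        else:
            # Assumed: any other mode carries text that can be kept as final.
            self.finalized.append(text)
            self.partial = ""

    def render(self) -> str:
        return "".join(self.finalized) + self.partial


transcript = Transcript()
for mode, text in [("2pass-online", "hel"), ("2pass-online", "lo"), ("2pass-offline", "Hello.")]:
    transcript.feed(json.dumps({"mode": mode, "text": text, "wav_name": "demo"}))
after_final = transcript.render()
transcript.feed(json.dumps({"mode": "2pass-online", "text": "wor", "wav_name": "demo"}))
after_partial = transcript.render()
```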
funasr/runtime/websocket/readme.md
@@ -1,3 +1,5 @@
([简体中文](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/websocket/readme_zh.md)|English)
# Service with websocket-cpp
## Export the model
funasr/runtime/websocket/readme_zh.md
New file
@@ -0,0 +1,190 @@
(简体中文|[English](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/websocket/readme.md))
# C++ deployment over the websocket protocol
## Quick start
### Starting the image
Pull and start the docker image of the FunASR runtime-SDK with the following commands:
```shell
sudo docker pull registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
sudo docker run -p 10095:10095 -it --privileged=true -v /root:/workspace/models registry.cn-hangzhou.aliyuncs.com/funasr_repo/funasr:funasr-runtime-sdk-cpu-0.1.0
```
If docker is not installed, see [Docker installation](#Docker安装)
### Starting the server
Once the docker container is running, start the funasr-wss-server program:
```shell
cd FunASR/funasr/runtime
./run_server.sh \
  --download-model-dir /workspace/models \
  --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
```
For details on the server parameters, see [Server parameters](#服务端参数介绍)
### Client testing and usage
Download the client test tool directory samples
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_samples.tar.gz
```
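The Python client is used as the example here. It supports several audio input formats (.wav, .pcm, .mp3, etc.), video input (.mp4, etc.), and a multi-file list wav.scp. For clients in other languages, see the documentation ([click here](#客户端用法详解)); for customized service deployment, see [How to customize service deployment](#如何定制服务部署)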
```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav"
```
------------------
## Detailed steps
### Downloading dependencies
#### Download onnxruntime
```shell
bash third_party/download_onnxruntime.sh
```
#### Download ffmpeg
```shell
bash third_party/download_ffmpeg.sh
```
#### Install openblas and openssl
```shell
sudo apt-get install libopenblas-dev libssl-dev #ubuntu
# sudo yum -y install openblas-devel openssl-devel #centos
```
### Build
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/websocket
mkdir build && cd build
cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
make
```
### Starting the service
#### Example: starting from models in modelscope
```shell
./funasr-wss-server  \
  --download-model-dir /workspace/models \
  --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
```
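Note: in the example above, `model-dir`, `vad-dir`, and `punc-dir` are model names in modelscope; the models are downloaded directly from modelscope and exported as quantized onnx. To start from local models, change them to local absolute paths.
#### Example: starting from local models
##### Export the models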
```shell
python -m funasr.export.export_model \
--export-dir ./export \
--type onnx \
--quantize True \
--model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
--model-name damo/speech_fsmn_vad_zh-cn-16k-common-pytorch \
--model-name damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch
```
For a detailed description of the export process, see ([click here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
##### Start the service
```shell
./funasr-wss-server  \
  --download-model-dir /workspace/models \
  --model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --vad-dir ./export/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
```
#### Command parameters:
```text
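--download-model-dir  Directory that models are downloaded to when a model ID is given and the model is fetched from Modelscope. Can be left unset when starting from local models.
--model-dir  ASR model ID in modelscope, or the absolute path of a local model
--quantize  True for the quantized ASR model, False for the non-quantized ASR model; default is True
--vad-dir  VAD model ID in modelscope, or the absolute path of a local model
--vad-quant   True for the quantized VAD model, False for the non-quantized VAD model; default is True
--punc-dir  Punctuation model ID in modelscope, or the absolute path of a local model
--punc-quant   True for the quantized PUNC model, False for the non-quantized PUNC model; default is True
--port  Port the server listens on; default is 10095
--decoder-thread-num  Number of inference threads started by the server; default is 8
--io-thread-num  Number of IO threads started by the server; default is 1
--certfile  ssl certificate file; default is ../../../ssl_key/server.crt
--keyfile   ssl key file; default is ../../../ssl_key/server.key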
```
### Client usage in detail
Download the client test tool directory samples
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/sample/funasr_samples.tar.gz
```
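Once the FunASR service has been deployed on the server, you can test and use the offline file transcription service with the steps below.
Clients are currently available in the following programming languages:
- [Python](#python-client)
- [CPP](#cpp-client)
- [html web version](#Html网页版)
- [Java](#Java-client)
#### python-client
To run the client directly for a test, follow the brief instructions below, which use the python version as the example: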
```shell
python3 wss_client_asr.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav" --output_dir "./results"
```
Command parameters:
```text
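--host is the ip of the machine where the FunASR runtime-SDK service is deployed; the default is the local ip (127.0.0.1). If the client and the service are not on the same server, change it to the ip of the deployment machine
--port 10095 is the port the service is deployed on
--mode offline selects offline file transcription
--audio_in is the audio to transcribe; a file path or a wav.scp file list is accepted
--output_dir is the directory where recognition results are saved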
```
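`--audio_in` also accepts a `wav.scp` file list. The README does not describe the file layout; the sketch below assumes the Kaldi-style convention of one `<name> <path>` pair per line, which should be checked against the client before use. `format_wav_scp` is a hypothetical helper:

```python
from pathlib import PurePosixPath
from typing import Iterable


def format_wav_scp(wav_paths: Iterable[str]) -> str:
    """Build wav.scp content: one '<name> <path>' line per audio file (assumed layout)."""
    lines = []
    for raw in wav_paths:
        path = PurePosixPath(raw)
        # Use the file name without its extension as the utterance name.
        lines.append(f"{path.stem} {path}")
    return "\n".join(lines) + "\n"


scp_text = format_wav_scp(["../audio/asr_example.wav", "../audio/second.wav"])
```

The resulting text can be written to a file such as `wav.scp` and passed as `--audio_in`.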
#### cpp-client
After entering the samples/cpp directory, you can test with the cpp client using the command below:
```shell
./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
```
Command parameters:
```text
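--server-ip is the ip of the machine where the FunASR runtime-SDK service is deployed; the default is the local ip (127.0.0.1). If the client and the service are not on the same server, change it to the ip of the deployment machine
--port 10095 is the port the service is deployed on
--wav-path is the audio file to transcribe; a file path is accepted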
```
#### Html web version
Open html/static/index.html in a browser to get the page shown below. It supports microphone input and file upload, so you can try it out directly
<img src="images/html.png"  width="900"/>
#### Java-client
```shell
FunasrWsClient --host localhost --port 10095 --audio_in ./asr_example.wav --mode offline
```
For details, see the documentation ([click here](../java/readme.md))
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/add-offline-websocket-srv/funasr/runtime/websocket) for contributing the websocket(cpp-api).
funasr/tasks/asr.py
@@ -1538,7 +1538,6 @@
        Return:
            model: ASR BAT model.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
funasr/train/trainer.py
@@ -369,7 +369,7 @@
                            ],
                            "scaler": scaler.state_dict() if scaler is not None else None,
                            "ema_model": model.encoder.ema.model.state_dict()
                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
                            if hasattr(model, "encoder") and hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
                        },
                        buffer,
                    )
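The added `hasattr(model, "encoder")` check keeps checkpoint saving from raising `AttributeError` on models that have no `encoder` attribute. The same guard can be written with chained `getattr` defaults, as in this sketch; `ema_state_dict` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for real models:

```python
from types import SimpleNamespace
from typing import Any, Optional


def ema_state_dict(model: Any) -> Optional[dict]:
    """Return the EMA weights if model.encoder.ema exists, else None."""
    encoder = getattr(model, "encoder", None)
    ema = getattr(encoder, "ema", None)  # getattr on None falls back to the default
    return ema.model.state_dict() if ema is not None else None


no_encoder = SimpleNamespace()
no_ema = SimpleNamespace(encoder=SimpleNamespace())
with_ema = SimpleNamespace(
    encoder=SimpleNamespace(
        ema=SimpleNamespace(model=SimpleNamespace(state_dict=lambda: {"w": 1}))
    )
)
results = [ema_state_dict(m) for m in (no_encoder, no_ema, with_ema)]
```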
funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
import edit_distance
from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    if num_peak != len(char_list) + 1:
        logging.warning("length mismatch, result might be incorrect.")
        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
    if num_peak > len(char_list) + 1:
        fire_place = fire_place[:len(char_list) - 1]
    elif num_peak < len(char_list) + 1:
        char_list = char_list[:num_peak + 1]
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
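The comment in this hunk states the invariant the rest of the function relies on: the number of peaks should equal the number of tokens plus one. For comparison, the sketch below truncates whichever side is longer so that the invariant holds afterwards. `reconcile_lengths` is a hypothetical helper on plain lists, and its slice bounds are chosen to satisfy the invariant; they are not copied from the hunk, which slices with `len(char_list) - 1` and `num_peak + 1`:

```python
from typing import List, Sequence, Tuple


def reconcile_lengths(
    fire_place: Sequence[int], char_list: Sequence[str]
) -> Tuple[List[int], List[str]]:
    """Truncate so that len(fire_place) == len(char_list) + 1 where possible."""
    fire_place = list(fire_place)
    char_list = list(char_list)
    num_peak = len(fire_place)
    if num_peak > len(char_list) + 1:
        # Too many peaks: keep one boundary per token plus the closing one.
        fire_place = fire_place[: len(char_list) + 1]
    elif num_peak < len(char_list) + 1:
        # Too few peaks: keep only the tokens that have both boundaries.
        char_list = char_list[: max(num_peak - 1, 0)]
    return fire_place, char_list


extra_peaks = reconcile_lengths([3, 9, 15, 22, 30], ["a", "b"])
missing_peaks = reconcile_lengths([3, 9, 15], ["a", "b", "c", "d"])
matched = reconcile_lengths([3, 9, 15], ["a", "b"])
```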
funasr/version.txt
@@ -1 +1 @@
0.6.9
0.7.0
setup.py
@@ -23,6 +23,7 @@
        "nltk>=3.4.5",
        # ASR
        "sentencepiece",
        "jieba",
        # TTS
        "pypinyin>=0.44.0",
        "espnet_tts_frontend",
@@ -122,4 +123,4 @@
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
)