雾聪
2023-06-01 dbffb474152b0a445bc093f6c0537e4d5d0ab5b4
Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
14 files modified
13 files added
1 file deleted
Changed files (1308 lines changed in total)
LICENSE 73
MODEL_LICENSE 73
README.md 1
docs/academic_recipe/asr_recipe.md 85
docs/model_zoo/modelscope_models.md 2
egs_modelscope/asr/TEMPLATE/README.md 3
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/README.md 1
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py 37
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py 40
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py 37
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py 32
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh 104
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/utils 1
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/README.md 1
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py 37
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py 40
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py 37
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py 32
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh 104
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/utils 1
funasr/bin/asr_infer.py 4
funasr/bin/asr_inference_launch.py 4
funasr/bin/build_trainer.py 2
funasr/build_utils/build_asr_model.py 5
funasr/models/decoder/sanm_decoder.py 7
funasr/models/e2e_asr_paraformer.py 485
funasr/models/encoder/sanm_encoder.py 55
funasr/runtime/python/websocket/wss_srv_asr.py 5
LICENSE
File was deleted
MODEL_LICENSE
New file
@@ -0,0 +1,73 @@
FunASR 模型开源协议
版本号:1.0
版权所有 (C) [2023-2028] [阿里巴巴集团]。保留所有权利。
感谢您选择 FunASR 开源模型。FunASR 开源模型包含一系列免费且开源的工业模型,让大家可以使用、修改、分享和学习该模型。
为了保证更好的社区合作,我们制定了以下协议,希望您仔细阅读并遵守本协议。
1 定义
本协议中,[FunASR 软件]指 FunASR 开源模型及其衍生品,包括 Finetune 后的模型;[您]指使用、修改、分享和学习[FunASR 软件]的个人或组织。
2 许可和限制
2.1 许可
您可以在遵守本协议的前提下,自由地使用、复制、修改和分享[FunASR 软件]。
2.2 限制
您在使用、复制、修改和分享[FunASR 软件]时,必须注明出处以及作者信息。并且,将[FunASR 软件]上传至其他第三方平台以供下载,需要获得额外许可,可通过官方邮件(funasr@list.alibaba-inc.com)进行申请(免费)。
3 责任和风险承担
[FunASR 软件]仅作为参考和学习使用,不对您使用或修改[FunASR 软件]造成的任何直接或间接损失承担任何责任。您对[FunASR 软件]的使用和修改应该自行承担风险。
4 终止
如果您违反本协议的任何条款,您的许可将自动终止,您必须停止使用、复制、修改和分享[FunASR 软件]。
5 修订
本协议可能会不时更新和修订。修订后的协议将在[FunASR 软件]官方仓库发布,并自动生效。如果您继续使用、复制、修改和分享[FunASR 软件],即表示您同意修订后的协议。
6 其他规定
本协议受到[国家/地区] 的法律管辖。如果任何条款被裁定为不合法、无效或无法执行,则该条款应被视为从本协议中删除,而其余条款应继续有效并具有约束力。
如果您对本协议有任何问题或意见,请联系我们。
版权所有© [2023-2028] [阿里巴巴集团]。保留所有权利。
FunASR Model Open Source License
Version 1.0
Copyright (C) [2023-2028] Alibaba Group. All rights reserved.
Thank you for choosing the FunASR open source models. The FunASR open source models contain a series of open-source models that allow everyone to use, modify, share, and learn from it.
To ensure better community collaboration, we have developed the following agreement and hope that you carefully read and abide by it.
1 Definitions
In this agreement, [FunASR software] refers to the FunASR open source model, and its derivatives, including fine-tuned models. [You] refer to individuals or organizations who use, modify, share, and learn from [FunASR software].
2 License and Restrictions
2.1 License
You are free to use, copy, modify, and share [FunASR software] under the conditions of this agreement.
2.2 Restrictions
You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. To upload the [FunASR software] to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com).
3 Responsibility and Risk
[FunASR software] is for reference and learning purposes only and is not responsible for any direct or indirect losses caused by your use or modification of [FunASR software]. You should take responsibility and risks for your use and modification of [FunASR software].
4 Termination
If you violate any terms of this agreement, your license will be automatically terminated, and you must stop using, copying, modifying, and sharing [FunASR software].
5 Revision
This agreement may be updated and revised from time to time. The revised agreement will be published in the FunASR official repository and automatically take effect. If you continue to use, copy, modify, and share [FunASR software], it means you agree to the revised agreement.
6 Other Provisions
This agreement is subject to the laws of [Country/Region]. If any provisions are found to be illegal, invalid, or unenforceable, they shall be deemed deleted from this agreement, and the remaining provisions shall remain valid and binding.
If you have any questions or comments about this agreement, please contact us.
Copyright (c) [2023-2028] Alibaba Group. All rights reserved.
README.md
@@ -88,6 +88,7 @@
## License
This project is licensed under the [The MIT License](https://opensource.org/licenses/MIT). FunASR also contains various third-party components and some code modified from other repos under other open source licenses.
The use of pretrained models is subject to the [model license](./MODEL_LICENSE).
## Citations
docs/academic_recipe/asr_recipe.md
@@ -37,7 +37,7 @@
```
Here is an example of loss:
<img src="images/loss.png" width="200"/>
<img src="./academic_recipe/images/loss.png" width="200"/>
The inference results are saved in `${exp_dir}/exp/${model_dir}/decode_asr_*/$dset`. The two main files are `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text, as follows:
```text
@@ -106,18 +106,37 @@
<unk>
```
* `<blank>`: indicates the blank token for CTC, must be in the first line
* `<s>`: indicates the start-of-sentence token, must be in the second line
* `</s>`: indicates the end-of-sentence token, must be in the third line
* `<unk>`: indicates the out-of-vocabulary token, must be in the last line
Four tokens must be specified:
* `<blank>`: (required), indicates the blank token for CTC, must be in the first line
* `<s>`: (required), indicates the start-of-sentence token, must be in the second line
* `</s>`: (required), indicates the end-of-sentence token, must be in the third line
* `<unk>`: (required), indicates the out-of-vocabulary token, must be in the last line
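The position rules above can be checked mechanically before training starts. The following is a minimal sketch, not part of FunASR: `check_token_list` and `load_tokens` are hypothetical helper names, and the one-token-per-line file layout is assumed from the example above.

```python
from typing import List

# Required token at each fixed position (index -1 means the last line).
REQUIRED_POSITIONS = {0: "<blank>", 1: "<s>", 2: "</s>", -1: "<unk>"}


def check_token_list(tokens: List[str]) -> List[str]:
    """Return a list of problems; an empty list means the layout is valid."""
    if len(tokens) < 4:
        return ["token list needs at least 4 entries"]
    problems = []
    for index, expected in REQUIRED_POSITIONS.items():
        if tokens[index] != expected:
            position = "the last line" if index == -1 else f"line {index + 1}"
            problems.append(
                f"expected {expected} at {position}, found {tokens[index]}"
            )
    return problems


def load_tokens(path: str) -> List[str]:
    """Read one token per line, skipping empty lines (assumed file layout)."""
    with open(path, encoding="utf-8") as handle:
        return [line.rstrip("\n") for line in handle if line.strip()]


if __name__ == "__main__":
    example = ["<blank>", "<s>", "</s>", "a", "b", "<unk>"]
    print(check_token_list(example))
```

In practice one would call `check_token_list(load_tokens(path))` on the file passed as `token_list` and abort if the returned list is not empty.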
### Stage 3: LM Training
### Stage 4: ASR Training
This stage achieves the training of the specified model. To start training, users should manually set `exp_dir` to specify the path for saving experimental results. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding. FunASR implements `train.py` for training different models and users can configure the following parameters if necessary.
This stage achieves the training of the specified model. To start training, users should manually set `exp_dir` to specify the path for saving experimental results. By default, the best `$keep_nbest_models` checkpoints on validation dataset will be averaged to generate a better model and adopted for decoding. FunASR implements `train.py` for training different models and users can configure the following parameters if necessary. The training command is as follows:
```sh
train.py \
    --task_name asr \
    --use_preprocessor true \
    --token_list $token_list \
    --data_dir ${feats_dir}/data \
    --train_set ${train_set} \
    --valid_set ${valid_set} \
    --data_file_names "wav.scp,text" \
    --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
    --speed_perturb ${speed_perturb} \
    --resume true \
    --output_dir ${exp_dir}/exp/${model_dir} \
    --config $asr_config \
    --ngpu $gpu_num \
    ...
```
* `task_name`: `asr` (Default), specify the task type of the current recipe
* `gpu_num`: `2` (Default), specify the number of GPUs for training. When `gpu_num > 1`, DistributedDataParallel (DDP, the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)) training will be enabled. Correspondingly, `CUDA_VISIBLE_DEVICES` should be set to specify which ids of GPUs will be used.
* `ngpu`: `2` (Default), specify the number of GPUs for training. When `ngpu > 1`, DistributedDataParallel (DDP, the detail can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)) training will be enabled. Correspondingly, `CUDA_VISIBLE_DEVICES` should be set to specify which ids of GPUs will be used.
* `use_preprocessor`: `true` (Default), specify whether to use pre-processing on each sample
* `token_list`: the path of token list for training
* `dataset_type`: `small` (Default). FunASR supports `small` dataset type for training small datasets. Besides, an optional iterable-style DataLoader based on [Pytorch Iterable-style DataPipes](https://pytorch.org/data/beta/torchdata.datapipes.iter.html) for large datasets is supported and users can specify `dataset_type=large` to enable it.
@@ -125,6 +144,7 @@
* `data_file_names`: `"wav.scp,text"` specify the speech and text file names for ASR
* `cmvn_file`: the path of cmvn file
* `resume`: `true`, whether to enable "checkpoint training"
* `output_dir`: the path for saving training results
* `config`: the path of configuration file, which is usually a YAML file in `conf` directory. In FunASR, the parameters of the training, including model, optimization, dataset, etc., can also be set in this file. Note that if the same parameters are specified in both recipe and config file, the parameters of recipe will be employed
### Stage 5: Decoding
@@ -156,25 +176,25 @@
## Change settings
Here we explain how to perform common custom settings, which can help users to modify scripts according to their own needs.
* Training with specified GPUs
### Training with specified GPUs
For example, if users want to use 2 GPUs with id `2` and `3, users can run the following command:
For example, if users want to use 2 GPUs with id `2` and `3`, users can run the following command:
```sh
. ./run.sh --CUDA_VISIBLE_DEVICES "2,3" --gpu_num 2 
```
* Start from/Stop at a specified stage
### Start from/Stop at a specified stage
The recipe includes several stages. Users can start from or stop at any stage. For example, the following command starts from the third stage and stops at the fifth stage:
```sh
. ./run.sh --stage 3 --stop_stage 5
```
* Training Steps
### Specify total training steps
FunASR supports two parameters to specify the training steps, namely `max_epoch` and `max_update`. `max_epoch` indicates the total training epochs while `max_update` indicates the total training steps. If these two parameters are specified at the same time, once the training reaches any one of these two parameters, the training will be stopped.
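The "whichever limit is reached first" rule can be written as a small predicate. This is an illustrative sketch only: `should_stop` is a hypothetical name, and the assumption that an unset limit is represented by `None` is ours, not taken from the FunASR trainer.

```python
from typing import Optional


def should_stop(
    epoch: int,
    num_updates: int,
    max_epoch: Optional[int] = None,
    max_update: Optional[int] = None,
) -> bool:
    """Stop as soon as either configured limit has been reached.

    epoch:       number of completed training epochs
    num_updates: number of completed optimizer steps
    A limit of None (an assumption of this sketch) means "not specified".
    """
    epoch_limit_hit = max_epoch is not None and epoch >= max_epoch
    update_limit_hit = max_update is not None and num_updates >= max_update
    return epoch_limit_hit or update_limit_hit


if __name__ == "__main__":
    # The update limit is hit long before the epoch limit.
    print(should_stop(epoch=3, num_updates=1000, max_epoch=20, max_update=1000))
```

With `max_epoch=20` and `max_update=1000`, training in this sketch stops at step 1000 even if only 3 epochs have finished.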
* Change the configuration of the model
### Change the configuration of the model
The configuration of the model is set in the config file `conf/train_*.yaml`. Specifically, the default encoder configuration of paraformer is as follows:
```
@@ -199,28 +219,49 @@
```
Users can change the encoder configuration by modifying these values. For example, if users want to use an encoder with 16 conformer blocks and each block has 8 attention heads, users just need to change `num_blocks` from 12 to 16 and change `attention_heads` from 4 to 8. Besides, the batch size, learning rate and other training hyper-parameters are also set in this config file. To change these hyper-parameters, users just need to directly change the corresponding values in this file. For example, the default learning rate is `0.0005`. If users want to change the learning rate to 0.0002, set the value of lr as `lr: 0.0002`.
* Use different input data type
### Use a different input data type
FunASR supports different input data types, including `sound`, `kaldi_ark`, `npy`, `text` and `text_int`. Users can specify any number and any type of input, which is achieved by `data_file_names` (in `run.sh`), `data_names` and `data_types` (in config file). For example, ASR task usually requires speech and the corresponding transcripts as input. If speech is saved as raw audio (such as wav format) and transcripts are saved as text format, users need to set `data_file_names=wav.scp,text` (any name is allowed, denoting wav list and text list), set `data_names=speech,text` and set `data_types=sound,text`. When the input type changes to FBank, users just need to modify `data_types=kaldi_ark,text`.
FunASR supports different input data types, including `sound`, `kaldi_ark`, `npy`, `text` and `text_int`. Users can specify any number and any type of input, which is achieved by `data_names` and `data_types` (in `conf/train_*.yaml`). For example, an ASR task usually requires speech and the transcripts as input. In FunASR, by default, speech is saved as raw audio (such as wav format) and transcripts are saved as text format. Correspondingly, `data_names` and `data_types` are set as follows (see `conf/train_*.yaml`):
```text
dataset_conf:
    data_names: speech,text
    data_types: sound,text
    ...
```
When the input type changes to FBank, users just need to set `data_types: kaldi_ark,text` in the config file. Note that `data_file_names` used in `train.py` should also be changed to the new file name.
* How to start from pre-trained models
Users can start training from a pre-trained model by specifying the `init_param` parameter. Here `init_param` indicates the path of the pre-trained model. In addition to directly loading all the parameters from one pre-trained model, loading part of the parameters from different pre-trained models is supported. For example, to load encoder parameters from the pre-trained model A and decoder parameters from the pre-trained model B, users can set `init_param` twice as follows:
```sh
train.py ... --init_param ${model_A_path}:encoder  --init_param ${model_B_path}:decoder ...
### How to resume training process
FunASR supports resuming training as follows:
```shell
train.py ... --resume true ...
```
* How to freeze part model parameters
### How to transfer / fine-tune from pre-trained models
FunASR supports transferring / fine-tuning from a pre-trained model by specifying the `init_param` parameter. The usage format is as follows:
```shell
train.py ... --init_param <file_path>:<src_key>:<dst_key>:<exclude_keys>  ...
```
For example, the following command loads all pre-trained parameters whose names start with `decoder`, except `decoder.embed`, and assigns them to `model.decoder2`:
```shell
train.py ... --init_param model.pb:decoder:decoder2:decoder.embed  ...
```
Besides, loading parameters from multiple pre-trained models is supported. For example, the following command achieves loading encoder parameters from the pre-trained model1 and decoder parameters from the pre-trained model2:
```sh
train.py ... --init_param model1.pb:encoder  --init_param model2.pb:decoder ...
```
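To make the `<file_path>:<src_key>:<dst_key>:<exclude_keys>` format concrete, here is a minimal parsing sketch. It is not FunASR's loader: `InitParam` and `parse_init_param` are hypothetical names, and treating `<exclude_keys>` as a comma-separated list is an assumption that the text above does not confirm.

```python
from typing import NamedTuple, Optional, Tuple


class InitParam(NamedTuple):
    file_path: str
    src_key: Optional[str]          # parameter prefix to read from the checkpoint
    dst_key: Optional[str]          # parameter prefix to write into the model
    exclude_keys: Tuple[str, ...]   # prefixes to skip while loading


def parse_init_param(spec: str) -> InitParam:
    """Split an --init_param value into its (up to four) colon-separated fields."""
    parts = spec.split(":")
    if len(parts) > 4:
        raise ValueError(f"too many ':'-separated fields in {spec!r}")
    # Missing trailing fields are treated as empty.
    parts += [""] * (4 - len(parts))
    file_path, src_key, dst_key, exclude = parts
    # Assumption: several excluded prefixes would be separated by commas.
    excluded = tuple(key for key in exclude.split(",") if key)
    return InitParam(file_path, src_key or None, dst_key or None, excluded)


if __name__ == "__main__":
    print(parse_init_param("model.pb:decoder:decoder2:decoder.embed"))
    print(parse_init_param("model1.pb:encoder"))
```

Because the fields are split on `:`, this sketch does not handle file paths that themselves contain a colon.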
### How to freeze part of the model parameters
In certain situations, users may want to fix part of the model parameters and update the rest. FunASR employs `freeze_param` to achieve this. For example, to fix all parameters like `encoder.*`, users need to set `freeze_param` as follows:
```sh
train.py ... --freeze_param encoder ...
```
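As an illustration of what "all parameters like `encoder.*`" selects, the sketch below picks parameter names by prefix. `select_frozen` is a hypothetical helper, and the exact matching rule (the name equals the prefix or starts with the prefix followed by a dot) is an assumption rather than FunASR's documented behavior.

```python
from typing import Iterable, List


def select_frozen(param_names: Iterable[str], freeze_prefixes: Iterable[str]) -> List[str]:
    """Return the parameter names covered by any of the given prefixes."""
    prefixes = list(freeze_prefixes)
    frozen = []
    for name in param_names:
        # "encoder" matches "encoder" and "encoder.xxx", but not "encoder2.xxx".
        if any(name == p or name.startswith(p + ".") for p in prefixes):
            frozen.append(name)
    return frozen


if __name__ == "__main__":
    names = [
        "encoder.embed.weight",
        "encoder.layer0.bias",
        "decoder.out.weight",
        "encoder2.w",
    ]
    print(select_frozen(names, ["encoder"]))
```

In a PyTorch model, the selected parameters would then have `requires_grad` set to `False` so that the optimizer leaves them unchanged.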
* ModelScope Usage
### ModelScope Usage
Users can use ModelScope for inference and fine-tuning based on a trained academic model. To achieve this, users need to run stage 6 of the script. In this stage, the files required by ModelScope are generated automatically. Users can then use the corresponding ModelScope interface by replacing the model name with the path of the locally trained model. For the detailed usage of the ModelScope interface, please refer to [ModelScope Usage](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html).
* Decoding by CPU or GPU
### Decoding by CPU or GPU
We support CPU and GPU decoding. For CPU decoding, set `gpu_inference=false` and set `njob` to specify the total number of CPU jobs. For GPU decoding, first set `gpu_inference=true`. Then set `gpuid_list` to specify which GPUs to use for decoding and `njob` to specify the number of decoding jobs on each GPU.
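The relationship between `gpu_inference`, `gpuid_list` and `njob` described above can be sketched as a function that assigns a device id to every decoding job. `plan_decoding_jobs` is a hypothetical name; the use of `-1` as the device id for CPU jobs follows the `infer.sh` script in this commit, while the rest is our reading of the paragraph above rather than the recipe's actual code.

```python
from typing import List


def plan_decoding_jobs(gpu_inference: bool, gpuid_list: str = "", njob: int = 1) -> List[str]:
    """Return one device id per decoding job ("-1" stands for a CPU job)."""
    if gpu_inference:
        gpu_ids = [gid for gid in gpuid_list.split(",") if gid]
        # njob jobs are started on each listed GPU.
        return [gid for gid in gpu_ids for _ in range(njob)]
    # CPU decoding: njob is the total number of jobs.
    return ["-1"] * njob


if __name__ == "__main__":
    print(plan_decoding_jobs(True, gpuid_list="0,1", njob=2))
    print(plan_decoding_jobs(False, njob=3))
```

Under these assumptions, `gpuid_list="0,1"` with `njob=2` yields four jobs, two on each GPU.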
docs/model_zoo/modelscope_models.md
@@ -1,7 +1,7 @@
# Pretrained Models on ModelScope
## Model License
You are free to use, copy, modify, and share FunASR under the conditions of this agreement. You should indicate the code and model source and author information when using, copying, modifying and sharing FunASR. To upload the FunASR industrial model to any third-party platform for download and use, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com). Full license could see [license](https://github.com/alibaba-damo-academy/FunASR/blob/main/LICENSE)
You are free to use, copy, modify, and share FunASR models under the conditions of this agreement. You should indicate the model source and author information when using, copying, modifying and sharing FunASR models. To upload the FunASR models to other third-party platforms for download, an additional license is required, which can be applied for free by sending an email to the official email address (funasr@list.alibaba-inc.com). The full model license is available [here](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE).
## Model Zoo
egs_modelscope/asr/TEMPLATE/README.md
@@ -24,7 +24,8 @@
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.4'
    model_revision='v1.0.6',
    mode='paraformer_streaming'
    )
import soundfile
speech, sample_rate = soundfile.read("example/asr_example.wav")
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/README.md
New file
@@ -0,0 +1 @@
../../TEMPLATE/README.md
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -1,39 +1,12 @@
import os
import logging
import torch
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.4'
    model_revision='v1.0.6',
    mode="paraformer_fake_streaming"
)
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
sample_offset = 0
chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
stride_size =  chunk_size[1] * 960
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
    if sample_offset + stride_size >= speech_length - 1:
        stride_size = speech_length - sample_offset
        param_dict["is_final"] = True
    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                    param_dict=param_dict)
    if len(rec_result) != 0:
        final_result += rec_result['text']
        print(rec_result)
print(final_result)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
New file
@@ -0,0 +1,40 @@
import os
import logging
import torch
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    mode="paraformer_streaming"
)
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
sample_offset = 0
chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
stride_size =  chunk_size[1] * 960
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
    if sample_offset + stride_size >= speech_length - 1:
        stride_size = speech_length - sample_offset
        param_dict["is_final"] = True
    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                    param_dict=param_dict)
    if len(rec_result) != 0:
        final_result += rec_result['text']
        print(rec_result)
print(final_result)
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
New file
@@ -0,0 +1,37 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        model_revision='v1.0.6',
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", data_path="./data")
    params.output_dir = "./checkpoint"              # path for saving the model
    params.data_path = "./example_data/"            # data path
    params.dataset_type = "small"                   # use "small" for small datasets; for more than 1000 hours of data, use "large"
    params.batch_bins = 1000                       # batch size; if dataset_type="small", batch_bins is in fbank feature frames; if dataset_type="large", batch_bins is in milliseconds
    params.max_epoch = 20                           # maximum number of training epochs
    params.lr = 0.00005                             # learning rate
    modelscope_finetune(params)
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
New file
@@ -0,0 +1,32 @@
import os
import shutil
import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def modelscope_infer(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=args.model,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        model_revision='v1.0.6',
        mode="paraformer_fake_streaming",
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
    inference_pipeline(audio_in=args.audio_in)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
    parser.add_argument('--output_dir', type=str, default="./results/")
    parser.add_argument('--decoding_mode', type=str, default="normal")
    parser.add_argument('--model_revision', type=str, default=None)
    parser.add_argument('--mode', type=str, default=None)
    parser.add_argument('--hotword_txt', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--gpuid', type=str, default="0")
    args = parser.parse_args()
    modelscope_infer(args)
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
New file
@@ -0,0 +1,104 @@
#!/usr/bin/env bash
set -e
set -u
set -o pipefail
stage=1
stop_stage=2
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
data_dir="./data/test"
output_dir="./results"
batch_size=32
gpu_inference=true    # whether to perform gpu decoding
gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
njob=32    # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"
. utils/parse_options.sh || exit 1;
if [ "${gpu_inference}" == "true" ]; then
    nj=$(echo $gpuid_list | awk -F "," '{print NF}')
else
    nj=$njob
    batch_size=1
    gpuid_list=""
    for JOB in $(seq ${nj}); do
        gpuid_list=$gpuid_list"-1,"
    done
fi
mkdir -p $output_dir/split
split_scps=""
for JOB in $(seq ${nj}); do
    split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
done
perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
if [ -n "${checkpoint_dir}" ]; then
  python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
  model=${checkpoint_dir}/${model}
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
    echo "Decoding ..."
    gpuid_list_array=(${gpuid_list//,/ })
    for JOB in $(seq ${nj}); do
        {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}
        mkdir -p ${output_dir}/output.$JOB
        python infer.py \
            --model ${model} \
            --audio_in ${output_dir}/split/wav.$JOB.scp \
            --output_dir ${output_dir}/output.$JOB \
            --batch_size ${batch_size} \
            --gpuid ${gpuid} \
            --mode "paraformer_fake_streaming"
        }&
    done
    wait
    mkdir -p ${output_dir}/1best_recog
    for f in token score text; do
        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
          for i in $(seq "${nj}"); do
              cat "${output_dir}/output.${i}/1best_recog/${f}"
          done | sort -k1 >"${output_dir}/1best_recog/${f}"
        fi
    done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
    echo "SpeechIO TIOBE textnorm"
    echo "$0 --> Normalizing REF text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        ${data_dir}/text \
        ${output_dir}/1best_recog/ref.txt
    echo "$0 --> Normalizing HYP text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        ${output_dir}/1best_recog/text.proc \
        ${output_dir}/1best_recog/rec.txt
    grep -v $'\t$' ${output_dir}/1best_recog/rec.txt > ${output_dir}/1best_recog/rec_non_empty.txt
    echo "$0 --> computing WER/CER and alignment ..."
    ./utils/error_rate_zh \
        --tokenizer char \
        --ref ${output_dir}/1best_recog/ref.txt \
        --hyp ${output_dir}/1best_recog/rec_non_empty.txt \
        ${output_dir}/1best_recog/DETAILS.txt | tee ${output_dir}/1best_recog/RESULTS.txt
    rm -rf ${output_dir}/1best_recog/rec.txt ${output_dir}/1best_recog/rec_non_empty.txt
fi
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/utils
New file
@@ -0,0 +1 @@
../../TEMPLATE/utils/
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/README.md
New file
@@ -0,0 +1 @@
../../TEMPLATE/README.md
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py
@@ -1,39 +1,12 @@
import os
import logging
import torch
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.4'
    model_revision='v1.0.6',
    mode="paraformer_fake_streaming"
)
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
sample_offset = 0
chunk_size = [8, 8, 4] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
stride_size =  chunk_size[1] * 960
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
    if sample_offset + stride_size >= speech_length - 1:
        stride_size = speech_length - sample_offset
        param_dict["is_final"] = True
    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                    param_dict=param_dict)
    if len(rec_result) != 0:
        final_result += rec_result['text']
        print(rec_result)
print(final_result.strip())
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py
New file
@@ -0,0 +1,40 @@
import os
import logging
import torch
import soundfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
os.environ["MODELSCOPE_CACHE"] = "./"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.6',
    mode="paraformer_streaming"
)
model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
speech_length = speech.shape[0]
sample_offset = 0
chunk_size = [8, 8, 4] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
stride_size =  chunk_size[1] * 960
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
final_result = ""
for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
    if sample_offset + stride_size >= speech_length - 1:
        stride_size = speech_length - sample_offset
        param_dict["is_final"] = True
    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
                                    param_dict=param_dict)
    if len(rec_result) != 0:
        final_result += rec_result['text']
        print(rec_result)
print(final_result)
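Review note: the loop above advances by `chunk_size[1] * 960` samples per pipeline call. Assuming 16 kHz input (the model name says 16k), 960 samples is 60 ms, which agrees with the 480 ms and 600 ms figures in the inline comment. The sketch below reproduces only the windowing arithmetic, without ModelScope. `chunk_ms` and `iter_chunks` are hypothetical helper names, and unlike the demo the generator stops right after the final window.

```python
SAMPLE_RATE = 16000        # assumption: 16 kHz input, as the model name suggests
SAMPLES_PER_FRAME = 960    # constant taken from the demo (60 ms at 16 kHz)


def chunk_ms(chunk_size):
    """Duration in milliseconds of one stride; the middle entry sets the stride."""
    return chunk_size[1] * SAMPLES_PER_FRAME * 1000 / SAMPLE_RATE


def iter_chunks(num_samples, chunk_size=(8, 8, 4)):
    """Yield (start, end, is_final) windows over a waveform of num_samples samples."""
    stride = chunk_size[1] * SAMPLES_PER_FRAME
    for start in range(0, num_samples, stride):
        # Same test as the demo: the window that reaches the end is the final one.
        is_final = start + stride >= num_samples - 1
        end = num_samples if is_final else start + stride
        yield start, end, is_final
        if is_final:
            break


if __name__ == "__main__":
    print(chunk_ms([8, 8, 4]), chunk_ms([5, 10, 5]))
    for window in iter_chunks(20000):
        print(window)
```

Each yielded window corresponds to one `inference_pipeline(audio_in=speech[start:end], param_dict=...)` call in the demo, with `param_dict["is_final"]` set from the third element.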
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py
New file
@@ -0,0 +1,37 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        model_revision='v1.0.6',
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()
if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", data_path="./data")
    params.output_dir = "./checkpoint"              # model save path
    params.data_path = "./example_data/"            # data path
    params.dataset_type = "small"                   # use "small" for small datasets; use "large" if the data exceeds 1000 hours
    params.batch_bins = 1000                       # batch size: in fbank feature frames if dataset_type="small", in milliseconds if dataset_type="large"
    params.max_epoch = 20                           # maximum number of training epochs
    params.lr = 0.00005                             # learning rate
    modelscope_finetune(params)
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
New file
@@ -0,0 +1,32 @@
import os
import shutil
import argparse
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
def modelscope_infer(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=args.model,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        model_revision='v1.0.6',
        mode="paraformer_fake_streaming",
        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
    )
    inference_pipeline(audio_in=args.audio_in)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
    parser.add_argument('--output_dir', type=str, default="./results/")
    parser.add_argument('--decoding_mode', type=str, default="normal")
    parser.add_argument('--model_revision', type=str, default=None)
    parser.add_argument('--mode', type=str, default=None)
    parser.add_argument('--hotword_txt', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--gpuid', type=str, default="0")
    args = parser.parse_args()
    modelscope_infer(args)
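Review note: `infer.py` parses `--mode` and `--model_revision`, but the `pipeline(...)` call uses fixed values for both, so the command-line values have no effect. One possible way to honor them is sketched below. `build_parser` and `pipeline_overrides` are hypothetical names, and the fallback values are simply the ones hardcoded in this file.

```python
import argparse

DEFAULT_REVISION = "v1.0.6"                 # value hardcoded in infer.py
DEFAULT_MODE = "paraformer_fake_streaming"  # value hardcoded in infer.py


def build_parser():
    """Parser with only the two options relevant to this sketch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_revision", type=str, default=None)
    parser.add_argument("--mode", type=str, default=None)
    return parser


def pipeline_overrides(args):
    """Return revision/mode keyword arguments, preferring CLI values over defaults."""
    return {
        "model_revision": args.model_revision or DEFAULT_REVISION,
        "mode": args.mode or DEFAULT_MODE,
    }


if __name__ == "__main__":
    print(pipeline_overrides(build_parser().parse_args()))
```

The returned dict could then be unpacked into the existing `pipeline(...)` call in place of the two literal arguments.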
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh
New file
@@ -0,0 +1,104 @@
#!/usr/bin/env bash
set -e
set -u
set -o pipefail
stage=1
stop_stage=2
model="damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
data_dir="./data/test"
output_dir="./results"
batch_size=32
gpu_inference=true    # whether to perform gpu decoding
gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
njob=32    # number of jobs for CPU decoding; only used when gpu_inference=false
checkpoint_dir=
checkpoint_name="valid.cer_ctc.ave.pb"
. utils/parse_options.sh || exit 1;
if [ "${gpu_inference}" == "true" ]; then
    nj=$(echo $gpuid_list | awk -F "," '{print NF}')
else
    nj=$njob
    batch_size=1
    gpuid_list=""
    for JOB in $(seq ${nj}); do
        gpuid_list=$gpuid_list"-1,"
    done
fi
mkdir -p $output_dir/split
split_scps=""
for JOB in $(seq ${nj}); do
    split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
done
perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
if [ -n "${checkpoint_dir}" ]; then
  python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
  model=${checkpoint_dir}/${model}
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
    echo "Decoding ..."
    gpuid_list_array=(${gpuid_list//,/ })
    for JOB in $(seq ${nj}); do
        {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}
        mkdir -p ${output_dir}/output.$JOB
        python infer.py \
            --model ${model} \
            --audio_in ${output_dir}/split/wav.$JOB.scp \
            --output_dir ${output_dir}/output.$JOB \
            --batch_size ${batch_size} \
            --gpuid ${gpuid} \
            --mode "paraformer_fake_streaming"
        }&
    done
    wait
    mkdir -p ${output_dir}/1best_recog
    for f in token score text; do
        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
          for i in $(seq "${nj}"); do
              cat "${output_dir}/output.${i}/1best_recog/${f}"
          done | sort -k1 >"${output_dir}/1best_recog/${f}"
        fi
    done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
    echo "SpeechIO TIOBE textnorm"
    echo "$0 --> Normalizing REF text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        ${data_dir}/text \
        ${output_dir}/1best_recog/ref.txt
    echo "$0 --> Normalizing HYP text ..."
    ./utils/textnorm_zh.py \
        --has_key --to_upper \
        ${output_dir}/1best_recog/text.proc \
        ${output_dir}/1best_recog/rec.txt
    grep -v $'\t$' ${output_dir}/1best_recog/rec.txt > ${output_dir}/1best_recog/rec_non_empty.txt
    echo "$0 --> computing WER/CER and alignment ..."
    ./utils/error_rate_zh \
        --tokenizer char \
        --ref ${output_dir}/1best_recog/ref.txt \
        --hyp ${output_dir}/1best_recog/rec_non_empty.txt \
        ${output_dir}/1best_recog/DETAILS.txt | tee ${output_dir}/1best_recog/RESULTS.txt
    rm -rf ${output_dir}/1best_recog/rec.txt ${output_dir}/1best_recog/rec_non_empty.txt
fi
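Review note: the top of `infer.sh` derives the number of parallel jobs and the device id of each job from `gpu_inference`, `gpuid_list` and `njob`. The same mapping is sketched below in Python for readability. `plan_jobs` is a hypothetical helper, not part of the repository, and its defaults are copied from the script.

```python
def plan_jobs(gpu_inference, gpuid_list="0,1", njob=32, batch_size=32):
    """Return (device id per job, batch size), mirroring the setup in infer.sh."""
    if gpu_inference:
        # One job per GPU listed in the comma-separated string.
        gpu_ids = [g for g in gpuid_list.split(",") if g]
        return gpu_ids, batch_size
    # CPU decoding: njob jobs, each with device id -1 and batch size 1.
    return ["-1"] * njob, 1


if __name__ == "__main__":
    print(plan_jobs(True))
    print(plan_jobs(False, njob=4))
```

The length of the returned list plays the role of `nj`, which also determines how many `wav.$JOB.scp` splits are created.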
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/utils
New file
@@ -0,0 +1 @@
../../TEMPLATE/utils/
funasr/bin/asr_infer.py
@@ -305,6 +305,7 @@
            nbest: int = 1,
            frontend_conf: dict = None,
            hotword_list_or_file: str = None,
            decoding_ind: int = 0,
            **kwargs,
    ):
        assert check_argument_types()
@@ -415,6 +416,7 @@
        self.nbest = nbest
        self.frontend = frontend
        self.encoder_downsampling_factor = 1
        self.decoding_ind = decoding_ind
        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4
@@ -452,7 +454,7 @@
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch)
        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
        if isinstance(enc, tuple):
            enc = enc[0]
        # assert len(enc) == 1, len(enc)
funasr/bin/asr_inference_launch.py
@@ -1638,6 +1638,8 @@
        return inference_uniasr(**kwargs)
    elif mode == "paraformer":
        return inference_paraformer(**kwargs)
    elif mode == "paraformer_fake_streaming":
        return inference_paraformer(**kwargs)
    elif mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    elif mode.startswith("paraformer_vad"):
@@ -1920,4 +1922,4 @@
if __name__ == "__main__":
    main()
    main()
funasr/bin/build_trainer.py
@@ -23,6 +23,8 @@
        from funasr.tasks.asr import ASRTask as ASRTask
    elif mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_streaming":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_vad_punc":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
funasr/build_utils/build_asr_model.py
@@ -23,7 +23,7 @@
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
@@ -82,6 +82,7 @@
        asr=ASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_online=ParaformerOnline,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
@@ -293,7 +294,7 @@
            token_list=token_list,
            **args.model_conf,
        )
    elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
funasr/models/decoder/sanm_decoder.py
@@ -935,6 +935,7 @@
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -955,9 +956,13 @@
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
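Review note: the new `chunk_mask` branch multiplies the encoder padding mask into every row of the chunk mask and, when the target length differs from the number of mask rows, appends a copy of the second-to-last row (the slice `-2:-1`), not the last one. The pure-Python sketch below illustrates this for a single utterance, with nested lists standing in for tensors. `combine_masks` is a hypothetical name, and the batch dimension is dropped for brevity.

```python
def combine_masks(memory_mask, chunk_mask, tgt_len):
    """Combine a padding mask (T entries) with a chunk mask (L rows of T entries)."""
    # Equivalent of broadcasting a (1, T) mask against an (L, T) mask.
    combined = [[m * c for m, c in zip(memory_mask, row)] for row in chunk_mask]
    if tgt_len != len(combined):
        # Mirrors memory_mask[:, -2:-1, :]: the second-to-last row is appended.
        combined = combined + [list(row) for row in combined[-2:-1]]
    return combined


if __name__ == "__main__":
    padding = [1, 1, 1, 0]
    chunks = [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 1]]
    for row in combine_masks(padding, chunks, tgt_len=4):
        print(row)
```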
funasr/models/e2e_asr_paraformer.py
@@ -153,6 +153,7 @@
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
@@ -160,6 +161,7 @@
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
                decoding_ind: int
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
@@ -176,7 +178,11 @@
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
@@ -272,12 +278,13 @@
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # 1. Extract feats
@@ -299,11 +306,25 @@
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out, encoder_out_lens, _ = self.encoder(
                    feats, feats_lengths, ctc=self.ctc, ind=ind
                )
                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                            encoder_out_lens,
                                                                                            chunk_outs=None)
            else:
                encoder_out, encoder_out_lens, _ = self.encoder(
                    feats, feats_lengths, ctc=self.ctc
                )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                            encoder_out_lens,
                                                                                            chunk_outs=None)
            else:
                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
@@ -592,9 +613,137 @@
    """
    def __init__(
            self, *args, **kwargs,
            self,
            vocab_size: int,
            token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            encoder: AbsEncoder,
            decoder: AbsDecoder,
            ctc: CTC,
            ctc_weight: float = 0.5,
            interctc_weight: float = 0.0,
            ignore_id: int = -1,
            blank_id: int = 0,
            sos: int = 1,
            eos: int = 2,
            lsm_weight: float = 0.0,
            length_normalized_loss: bool = False,
            report_cer: bool = True,
            report_wer: bool = True,
            sym_space: str = "<space>",
            sym_blank: str = "<blank>",
            extract_feats_in_collect_stats: bool = True,
            predictor=None,
            predictor_weight: float = 0.0,
            predictor_bias: int = 0,
            sampling_ratio: float = 0.2,
            decoder_attention_chunk_type: str = 'chunk',
            share_embedding: bool = False,
            preencoder: Optional[AbsPreEncoder] = None,
            postencoder: Optional[AbsPostEncoder] = None,
            use_1st_decoder_loss: bool = False,
    ):
        super().__init__(*args, **kwargs)
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0
        self.scama_mask = None
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = decoder_attention_chunk_type
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.use_1st_decoder_loss = use_1st_decoder_loss
    def forward(
            self,
@@ -602,6 +751,7 @@
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
            decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
@@ -609,6 +759,7 @@
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
                decoding_ind: int
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
@@ -625,7 +776,11 @@
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
@@ -638,8 +793,12 @@
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                                    encoder_out_lens,
                                                                                                    chunk_outs=None)
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
@@ -652,8 +811,14 @@
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                if hasattr(self.encoder, "overlap_chunk_cls"):
                    encoder_out_ctc, encoder_out_lens_ctc = \
                        self.encoder.overlap_chunk_cls.remove_chunk(
                            intermediate_out,
                            encoder_out_lens,
                            chunk_outs=None)
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
@@ -672,7 +837,7 @@
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
@@ -684,8 +849,12 @@
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        if self.use_1st_decoder_loss and pre_loss_att is not None:
            loss = loss + pre_loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
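Review note: the hunk above adds an optional first-pass decoder term to the training loss. The arithmetic of the branch visible in this hunk is sketched below with plain floats. `total_loss` is a hypothetical function, and the branches for `ctc_weight` equal to 0.0 or 1.0 are not shown in this diff, so they are left out here.

```python
def total_loss(loss_ctc, loss_att, loss_pre, pre_loss_att,
               ctc_weight, predictor_weight, use_1st_decoder_loss):
    """Weighted sum of the CTC, attention and predictor losses, as in the visible branch."""
    loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att + loss_pre * predictor_weight
    # The first-pass decoder loss is added only when enabled and actually computed.
    if use_1st_decoder_loss and pre_loss_att is not None:
        loss = loss + pre_loss_att
    return loss


if __name__ == "__main__":
    print(total_loss(2.0, 4.0, 1.0, 0.5, 0.5, 1.0, True))
```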
@@ -697,14 +866,67 @@
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def encode_chunk(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
@@ -750,11 +972,240 @@
        return encoder_out, torch.tensor([encoder_out.size(1)])
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
                                                                                           device=encoder_out.device,
                                                                                           batch_size=encoder_out.size(
                                                                                               0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
                                                                                   batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
                                                                              ys_pad,
                                                                              encoder_out_mask,
                                                                              ignore_id=self.ignore_id,
                                                                              mask_chunk_predictor=mask_chunk_predictor,
                                                                              target_label_length=ys_pad_lens,
                                                                              )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
                                                                                             encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                        encoder_out_lens,
                                                                                        chunk_outs=None)
        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
            else:
                sematic_embeds, decoder_out_1st = \
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
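    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
        # Run a first decoder pass without gradients, then replace a random subset of
        # positions (sized by the number of wrong predictions times sampling_ratio)
        # with ground-truth token embeddings.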
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    # Keep the index on the same device as input_mask instead of forcing CUDA.
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
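    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
        # Same sampling as sampler(), but the first decoder pass keeps gradients and
        # its attention loss is returned as pre_loss_att.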
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
        )
        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
                # Keep the index on the same device as input_mask instead of forcing CUDA.
                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
                                                                                           device=encoder_out.device,
                                                                                           batch_size=encoder_out.size(
                                                                                               0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
                                                                                   batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
                                                                                           None,
                                                                                           encoder_out_mask,
                                                                                           ignore_id=self.ignore_id,
                                                                                           mask_chunk_predictor=mask_chunk_predictor,
                                                                                           target_label_length=None,
                                                                                           )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
                                                                                             encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=None,
                is_training=self.training,
            )
        self.scama_mask = scama_mask
        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
    def calc_predictor_chunk(self, encoder_out, cache=None):
        pre_acoustic_embeds, pre_token_length = \
            self.predictor.forward_chunk(encoder_out, cache["encoder"])
        return pre_acoustic_embeds, pre_token_length
    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
        decoder_outs = self.decoder.forward_chunk(
@@ -1800,4 +2251,4 @@
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update
funasr/models/encoder/sanm_encoder.py
@@ -633,6 +633,8 @@
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "pe_online":
            self.embed = StreamSinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
@@ -818,6 +820,59 @@
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = None):
        # No cache yet (None or empty): nothing to prepend to this chunk.
        if not cache:
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        return overlap_feats
    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      ctc: CTC = None,
                      ):
        xs_pad *= self.output_size() ** 0.5
        if self.embed is not None:
            xs_pad = self.embed(xs_pad, cache)
        if cache["tail_chunk"]:
            xs_pad = to_device(cache["feats"], device=xs_pad.device)
        else:
            xs_pad = self._add_overlap_chunk(xs_pad, cache)
        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, None, None, None, None)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)
                    intermediate_outs.append((layer_idx + 1, encoder_out))
                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None
    def gen_tf2torch_map_dict(self):
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
funasr/runtime/python/websocket/wss_srv_asr.py
@@ -58,7 +58,8 @@
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
-    model_revision='v1.0.4')
+    model_revision='v1.0.6',
+    mode='paraformer_streaming')
print("model loaded")
@@ -207,4 +208,4 @@
else:
    start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()