Merge pull request #363 from alibaba-damo-academy/main
update with main
100 files changed
7 files deleted
19 files added
1 file copied
7 files renamed
| | |
| | | [**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new) |
| | | | [**Highlights**](#highlights) |
| | | | [**Installation**](#installation) |
| | | | [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) |
| | | | [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html) |
| | | | [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C) |
| | | | [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations) |
| | | | [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime) |
| | | | [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md) |
| | | | [**Contact**](#contact) |
| | | |
| | | |
| | |
| | | |
| | | ## Installation |
| | | |
| | | Install from pip: |
| | | ```shell |
| | | pip install -U funasr |
| | | # Users in China can install from a mirror instead: |
| | | # pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | ``` |
| | | |
| | | Or install from source: |
| | | |
| | | |
| | | ``` sh |
| | | git clone https://github.com/alibaba/FunASR.git && cd FunASR |
| | | pip install -e ./ |
| | | # Users in China can install from a mirror instead: |
| | | # pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | |
| | | ``` |
| | | To use the pretrained models hosted on ModelScope, also install the modelscope package: |
| | | |
| | | ```shell |
| | | pip install -U modelscope |
| | | # Users in China can install from a mirror instead: |
| | | # pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | ``` |
| | | |
| | | For more details, please refer to the [installation guide](https://github.com/alibaba-damo-academy/FunASR/wiki). |
| | | |
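To confirm that the installation steps above succeeded, you can query the installed distribution versions from Python. This is a minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+); `installed_versions` is a hypothetical helper, not part of FunASR or ModelScope.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution: version}, or None for distributions that are not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for pkg, ver in installed_versions(["funasr", "modelscope"]).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```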
| | | ## Usage |
| | | For users who are new to FunASR and ModelScope, please refer to the FunASR docs ([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)). |
| | | |
| | | ## Contact |
| | | |
| | |
| | | ## Model Zoo |
| | | We provide several pretrained models trained on different datasets. Details of the models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition). |
| | | |
| | | | Datasets | Hours | Model | Online/Offline | Language | Framework | Checkpoint | |
| | | |:-----:|:-----:|:--------------:|:--------------:| :---: | :---: | --- | |
| | | | Alibaba Speech Data | 60000 | Paraformer | Offline | CN | Pytorch |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | |
| | | | Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | |
| | | | Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | |
| | | | Alibaba Speech Data | 50000 | Paraformer | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 50000 | UniASR | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 50000 | UniASR | Offline | CN | Tensorflow |[speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | |
| | | | Alibaba Speech Data | 50000 | UniASR | Online | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 50000 | UniASR | Offline | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) | |
| | | | Alibaba Speech Data | 20000 | UniASR | Online | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 20000 | UniASR | Offline | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) | |
| | | | Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) | |
| | | | Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) | |
| | | | Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) | |
| | | | Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) | |
| | | | Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) | |
| | | | Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) | |
| | | | Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) | |
| | | | AISHELL-1 | 178 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | |
| | | | AISHELL-2 | 1000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) | |
| | | | AISHELL-1 | 178 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | |
| | | | AISHELL-2 | 1000 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | |
| | | | AISHELL-1 | 178 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | |
| | | | AISHELL-2 | 1000 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | |
| | | ### Speech Recognition Models |
| | | #### Paraformer Models |
| | | | Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes | |
| | | |:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------| |
| | | | [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s | |
| | | | [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Handles input wav of arbitrary length | |
| | | | [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Supports hotword customization based on incentive enhancement, improving hotword recall and precision | |
| | | | [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s | |
| | | | [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Supports streaming input | |
| | | | [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command words recognition | |
| | | | [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | | |
| | | | [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | | |
| | | | [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | | |
| | | | [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | | |
| | | |
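Several models above note `Duration of input wav <= 20s`, while Paraformer-large-long handles arbitrary lengths (its model ID suggests it bundles VAD and punctuation). If you must feed long audio to a length-limited model, one crude workaround is to cut it into fixed windows of at most 20 s. The sketch below only computes sample ranges; `split_into_windows` is a hypothetical helper, it does not use VAD, and its cuts can fall mid-word, so expect some errors near the boundaries.

```python
from typing import List, Tuple


def split_into_windows(num_samples: int, sample_rate: int = 16000,
                       max_seconds: float = 20.0) -> List[Tuple[int, int]]:
    """Return [start, end) sample ranges, each at most max_seconds long.

    Fixed-length windowing only (no VAD), so boundaries may cut through speech.
    """
    if num_samples < 0 or sample_rate <= 0 or max_seconds <= 0:
        raise ValueError("num_samples must be >= 0; sample_rate and max_seconds must be > 0")
    window = int(max_seconds * sample_rate)
    return [(start, min(start + window, num_samples))
            for start in range(0, num_samples, window)]


if __name__ == "__main__":
    # 45 s of 16 kHz audio -> windows of 20 s, 20 s and 5 s
    for start, end in split_into_windows(45 * 16000):
        print(start, end, (end - start) / 16000)
```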
| | | |
| | | #### UniASR Models |
| | | | Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes | |
| | | |:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------| |
| | | | [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | Unified streaming and offline UniASR model | |
| | | | [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | Unified streaming and offline UniASR model | |
| | | | [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | Unified streaming and offline UniASR model | |
| | | | [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | Unified streaming and offline UniASR model | |
| | | | [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | Unified streaming and offline UniASR model | |
| | | |
| | | #### Conformer Models |
| | | | Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes | |
| | | |:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------| |
| | | | [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s | |
| | | | [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s | |
| | | |
| | | #### RNN-T Models |
| | | |
| | | ### Voice Activity Detection Models |
| | | |
| | | | Model Name | Training Data | Parameters | Sampling Rate | Notes | |
| | | |:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------| |
| | | | [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | | |
| | | | [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | | |
| | | |
| | | ### Punctuation Restoration Models |
| | | |
| | | | Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes | |
| | | |:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------| |
| | | | [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model | |
| | | | [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model | |
| | | |
| | | ### Language Models |
| | | |
| | | | Model Name | Training Data | Parameters | Vocab Size | Notes | |
| | | |:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------| |
| | | | [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | | |
| | | |
| | | ### Speaker Verification Models |
| | | |
| | | | Model Name | Training Data | Parameters | Number of Speakers | Notes | |
| | | |:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------| |
| | | | [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (?hours) | 17.5M | 3465 | | |
| | | | [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (?hours) | 61M | 6135 | | |
| | | |
| | | ### Speaker Diarization Models |
| | | |
| | | | Model Name | Training Data | Parameters | Notes | |
| | | |:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------| |
| | | | [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (?hours) | 40.5M | | |
| | | | [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (?hours) | 12M | | |
| | |
| | | if out_item['wrong'] > 0: |
| | | rst['wrong_sentences'] += 1 |
| | | cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') |
| | | cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n') |
| | | cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n') |
| | | |
| | | if rst['Wrd'] > 0: |
| | | rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) |
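The snippet above accumulates per-utterance error counts (`out_item['wrong']`) and reports `Err` as `wrong_words * 100 / Wrd`. Those counts come from an edit-distance alignment of reference and hypothesis tokens. Below is a minimal, hypothetical re-implementation of that per-utterance step (`align_counts` and `corpus_err` are not FunASR functions). Ties between substitution, insertion and deletion are broken by tuple ordering, so individual sub/ins/del counts may differ from FunASR's, while the total error count is always the minimum edit distance.

```python
from typing import Dict, Iterable, Sequence, Tuple


def align_counts(ref: Sequence[str], hyp: Sequence[str]) -> Dict[str, int]:
    """Levenshtein alignment of two token sequences with sub/ins/del counts."""
    n, m = len(ref), len(hyp)
    # dp[i][j] = (cost, sub, ins, del) for aligning ref[:i] with hyp[:j]
    dp = [[(0, 0, 0, 0)] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = (i, 0, 0, i)          # delete every reference token
    for j in range(1, m + 1):
        dp[0][j] = (j, 0, j, 0)          # insert every hypothesis token
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            c, s, ins, d = dp[i - 1][j - 1]
            best = (c, s, ins, d) if ref[i - 1] == hyp[j - 1] else (c + 1, s + 1, ins, d)
            c, s, ins, d = dp[i][j - 1]
            best = min(best, (c + 1, s, ins + 1, d))
            c, s, ins, d = dp[i - 1][j]
            best = min(best, (c + 1, s, ins, d + 1))
            dp[i][j] = best
    cost, sub, ins, dele = dp[n][m]
    return {"nwords": n, "wrong": cost, "sub": sub, "ins": ins, "del": dele,
            "cor": n - sub - dele}


def corpus_err(pairs: Iterable[Tuple[Sequence[str], Sequence[str]]]) -> float:
    """Err = wrong_words * 100 / Wrd, rounded to 2 decimals as in the snippet above."""
    wrd = wrong = 0
    for ref, hyp in pairs:
        out = align_counts(ref, hyp)
        wrd += out["nwords"]
        wrong += out["wrong"]
    return round(wrong * 100 / wrd, 2) if wrd > 0 else 0.0
```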
| | |
| | | # If text exists, compute CER |
| | | text_in = os.path.join(params["data_dir"], "text") |
| | | if os.path.exists(text_in): |
| | |     text_proc_file = os.path.join(best_recog_path, "text") |
| | |     compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer")) |
| | | |
| | | |
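`compute_wer(text_in, text_proc_file, cer_file)` compares the reference `text` file against the decoded `1best_recog/text` file. Both are Kaldi-style: one utterance per line, an utterance key followed by its transcript. The sketch below is a hypothetical, stdlib-only stand-in (not FunASR's `compute_wer`) that shows the file-level flow: parse both files into key-to-token dicts, score only keys present in both, and report CER as a percentage. It assumes character-level scoring with spaces removed, which suits Mandarin transcripts.

```python
from typing import Dict, Iterable, List


def parse_kaldi_text(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Parse 'utt_key transcript' lines into {key: list of characters}."""
    table: Dict[str, List[str]] = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if not parts:
            continue  # skip blank lines
        text = parts[1] if len(parts) > 1 else ""
        # Assumption: character-level scoring with spaces removed (Mandarin-style CER)
        table[parts[0]] = list(text.replace(" ", ""))
    return table


def edit_distance(a: List[str], b: List[str]) -> int:
    """Levenshtein distance using a rolling one-row table."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i] + [0] * len(b)
        for j, y in enumerate(b, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (x != y))  # substitution or match
        prev = cur
    return prev[-1]


def cer_from_tables(ref: Dict[str, List[str]], hyp: Dict[str, List[str]]) -> float:
    """Percent CER over utterance keys present in both tables."""
    errors = total = 0
    for key, hyp_tokens in hyp.items():
        if key in ref:
            errors += edit_distance(ref[key], hyp_tokens)
            total += len(ref[key])
    return round(errors * 100 / total, 2) if total else 0.0


def file_cer(ref_path: str, hyp_path: str) -> float:
    """CER between a reference text file and a decoded 1best_recog/text file."""
    with open(ref_path, encoding="utf-8") as f_ref, open(hyp_path, encoding="utf-8") as f_hyp:
        return cer_from_tables(parse_kaldi_text(f_ref), parse_kaldi_text(f_hyp))
```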
| | |
| | | # compute CER if the ground-truth text is available |
| | | text_in = os.path.join(params["data_dir"], "text") |
| | | if os.path.exists(text_in): |
| | |     text_proc_file = os.path.join(decoding_path, "1best_recog/text") |
| | |     compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer")) |
| | | |
| | | |
| | |
| | | |
| | | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then |
| | |     echo "Computing WER ..." |
| | |     cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc |
| | |     cp ${data_dir}/text ${output_dir}/1best_recog/text.ref |
| | |     python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer |
| | |     tail -n 3 ${output_dir}/1best_recog/text.cer |
| | | fi |
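The `[ $stage -le 2 ] && [ $stop_stage -ge 2 ]` guard runs a step only when its index lies in the closed interval `[stage, stop_stage]`, so a recipe can be resumed or stopped early from the command line. A minimal Python rendering of the same convention follows; the `run_stages` helper and the step names are hypothetical.

```python
from typing import Callable, Dict, List


def run_stages(steps: Dict[int, Callable[[], None]], stage: int, stop_stage: int) -> List[int]:
    """Run every step n with stage <= n <= stop_stage, in ascending order.

    Mirrors the shell guard `[ $stage -le n ] && [ $stop_stage -ge n ]`.
    Returns the indices that were executed.
    """
    executed = []
    for n in sorted(steps):
        if stage <= n <= stop_stage:
            steps[n]()
            executed.append(n)
    return executed


if __name__ == "__main__":
    steps = {
        1: lambda: print("stage 1: decoding"),
        2: lambda: print("stage 2: computing WER"),
    }
    run_stages(steps, stage=2, stop_stage=2)  # only the scoring step runs
```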
| | |
| | | batch_size=1 |
| | | ) |
| | | audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx)) |
| | | inference_pipline(audio_in=audio_in) |
| | | |
| | | def modelscope_infer(params): |
| | | # prepare for multi-GPU decoding |
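The multi-GPU decoding in these recipes (see `modelscope_infer_after_finetune` in this PR) splits `wav.scp` into `nj = ngpu * njob` contiguous shards, with the last shard absorbing the remainder, and maps 1-based job `idx` to GPU `(idx - 1) // njob`, indexing into `CUDA_VISIBLE_DEVICES` when that variable is set. The pure functions below restate that arithmetic so it can be checked without GPUs; they are a sketch with hypothetical names, not FunASR code.

```python
from typing import List, Optional


def split_scp(lines: List[str], nj: int) -> List[List[str]]:
    """Split lines into nj contiguous shards; the last shard takes any remainder."""
    if nj <= 0:
        raise ValueError("nj must be positive")
    per_job = len(lines) // nj
    shards = []
    for i in range(nj):
        start = i * per_job
        end = len(lines) if i == nj - 1 else start + per_job
        shards.append(lines[start:end])
    return shards


def gpu_for_job(idx: int, njob: int, visible_devices: Optional[str] = None) -> str:
    """Device string for 1-based job idx when njob jobs share each GPU."""
    gpu_id = (idx - 1) // njob
    if visible_devices is not None:
        # same as the recipe: pick the gpu_id-th entry of CUDA_VISIBLE_DEVICES
        return visible_devices.split(",")[gpu_id]
    return str(gpu_id)
```

When `wav.scp` has fewer lines than `nj`, every shard except the last is empty, which mirrors the recipe's `num_lines // nj` arithmetic.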
| | |
| | | import json |
| | | import os |
| | | import shutil |
| | | |
| | | from multiprocessing import Pool |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | | from funasr.utils.compute_wer import compute_wer |
| | | |
| | | |
| | | def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx): |
| | |     # decode one shard (wav.<idx>.scp) on the GPU assigned to job idx |
| | |     output_dir_job = os.path.join(output_dir, "output.{}".format(idx)) |
| | |     gpu_id = (int(idx) - 1) // njob |
| | |     if "CUDA_VISIBLE_DEVICES" in os.environ.keys(): |
| | |         gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",") |
| | |         os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id]) |
| | |     else: |
| | |         os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) |
| | |     inference_pipeline = pipeline( |
| | |         task=Tasks.auto_speech_recognition, |
| | |         model=model_dir, |
| | |         output_dir=output_dir_job, |
| | |         batch_size=1 |
| | |     ) |
| | |     audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx)) |
| | |     inference_pipeline(audio_in=audio_in) |
| | | |
| | | |
| | | def modelscope_infer_after_finetune(params): |
| | |     # prepare for multi-GPU decoding |
| | |     # copy the pretrained model's auxiliary files next to the fine-tuned checkpoint |
| | |     model_dir = params["model_dir"] |
| | |     pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"]) |
| | |     for file_name in params["required_files"]: |
| | |         if file_name == "configuration.json": |
| | |             with open(os.path.join(pretrained_model_path, file_name)) as f: |
| | |                 config_dict = json.load(f) |
| | |                 config_dict["model"]["am_model_name"] = params["decoding_model_name"] |
| | |             with open(os.path.join(model_dir, "configuration.json"), "w") as f: |
| | |                 json.dump(config_dict, f, indent=4, separators=(',', ': ')) |
| | |         else: |
| | |             shutil.copy(os.path.join(pretrained_model_path, file_name), |
| | |                         os.path.join(model_dir, file_name)) |
| | |     ngpu = params["ngpu"] |
| | |     njob = params["njob"] |
| | |     output_dir = params["output_dir"] |
| | |     if os.path.exists(output_dir): |
| | |         shutil.rmtree(output_dir) |
| | |     os.mkdir(output_dir) |
| | |     split_dir = os.path.join(output_dir, "split") |
| | |     os.mkdir(split_dir) |
| | |     # split wav.scp into nj shards; the last shard takes the remainder |
| | |     nj = ngpu * njob |
| | |     wav_scp_file = os.path.join(params["data_dir"], "wav.scp") |
| | |     with open(wav_scp_file) as f: |
| | |         lines = f.readlines() |
| | |         num_lines = len(lines) |
| | |         num_job_lines = num_lines // nj |
| | |     start = 0 |
| | |     for i in range(nj): |
| | |         end = start + num_job_lines |
| | |         file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1))) |
| | |         with open(file, "w") as f: |
| | |             if i == nj - 1: |
| | |                 f.writelines(lines[start:]) |
| | |             else: |
| | |                 f.writelines(lines[start:end]) |
| | |         start = end |
| | | |
| | |     # decode the shards in parallel; maxtasksperchild=1 gives each job a fresh |
| | |     # process, so one job's CUDA_VISIBLE_DEVICES edit cannot leak into the next |
| | |     p = Pool(nj, maxtasksperchild=1) |
| | |     jobs = [p.apply_async(modelscope_infer_after_finetune_core, |
| | |                           args=(model_dir, output_dir, split_dir, njob, str(i + 1))) |
| | |             for i in range(nj)] |
| | |     p.close() |
| | |     p.join() |
| | |     for job in jobs: |
| | |         job.get()  # re-raise any exception raised inside a worker |
| | | |
| | |     # combine decoding results |
| | |     best_recog_path = os.path.join(output_dir, "1best_recog") |
| | |     os.mkdir(best_recog_path) |
| | |     files = ["text", "token", "score"] |
| | |     for file in files: |
| | |         with open(os.path.join(best_recog_path, file), "w") as f: |
| | |             for i in range(nj): |
| | |                 job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file) |
| | |                 with open(job_file) as f_job: |
| | |                     lines = f_job.readlines() |
| | |                 f.writelines(lines) |
| | | |
| | |     # If text exists, compute CER |
| | |     text_in = os.path.join(params["data_dir"], "text") |
| | |     if os.path.exists(text_in): |
| | |         text_proc_file = os.path.join(best_recog_path, "token") |
| | |         compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer")) |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | |     params = {} |
| | |     params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline" |
| | |     params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"] |
| | |     params["model_dir"] = "./checkpoint" |
| | |     params["output_dir"] = "./results" |
| | |     params["data_dir"] = "./data/test" |
| | |     params["decoding_model_name"] = "20epoch.pb" |
| | |     params["ngpu"] = 1 |
| | |     params["njob"] = 1 |
| | |     modelscope_infer_after_finetune(params) |
| | | |
| File was renamed from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py |
| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| File was copied from egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py to egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py |
| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | |
| | | num_workers=0, |
| | | task=Tasks.speaker_diarization, |
| | | diar_model_config="sond.yaml", |
| | | model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch', |
| | | sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch", |
| | | sv_model_revision="master", |
| | | ) |
| | | |
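| | | # Use audio_list as input: the first audio is the speech to be diarized; the remaining audios are enrollment (voiceprint) recordings of the individual speakers |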
| | | audio_list = [ |
| | | "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav", |
| | | "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav", |
| | | "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav", |
| | | "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav", |
| | | "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav", |
| | | ] |
| | | |
| | | results = inference_diar_pipline(audio_in=audio_list) |
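The SOND example expects `audio_list` to hold the recording to be diarized first, followed by one enrollment recording per speaker. A small, hypothetical helper that builds such a list and rejects obviously malformed entries (for example, two URLs accidentally concatenated into one string) is sketched below; the single-`.wav`-per-entry check is an assumption for illustration, not a documented pipeline requirement.

```python
from typing import List, Sequence


def build_diar_audio_list(recording: str, enrollments: Sequence[str]) -> List[str]:
    """Return [recording, enrollment_1, ..., enrollment_k] for a SOND-style pipeline.

    Assumption: each entry is a single .wav path or URL; a string containing
    '.wav' more than once usually means two list items were concatenated.
    """
    if not enrollments:
        raise ValueError("at least one speaker enrollment recording is required")
    audio_list = [recording, *enrollments]
    for item in audio_list:
        if item.count(".wav") != 1:
            raise ValueError(f"malformed audio entry: {item!r}")
    return audio_list
```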
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | finish_count += 1 |
| | | # asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | |
| | | logging.info("decoding, utt: {}, predictions: {}".format(key, text)) |
| | | rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) |
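The `rtf avg` log line computes `100 * forward_time_total / (length_total * lfr_factor)`. Assume two things: `length_total` counts low-frame-rate (LFR) frames, each stacking `lfr_factor` original frames, and the original frame shift is 10 ms (100 frames per second). Under those assumptions, `length_total * lfr_factor / 100` is the audio duration in seconds, so the logged value is an ordinary real-time factor (processing time divided by audio time). The sketch below spells this out. `rtf_from_lfr_frames` is a hypothetical name.

```python
def rtf_from_lfr_frames(forward_time_s: float, num_lfr_frames: int,
                        lfr_factor: int, frame_shift_ms: float = 10.0) -> float:
    """Real-time factor = processing time / audio duration.

    Assumes each LFR frame stacks `lfr_factor` original frames spaced
    `frame_shift_ms` apart (10 ms by default).
    """
    audio_seconds = num_lfr_frames * lfr_factor * frame_shift_ms / 1000.0
    return forward_time_s / audio_seconds


rtf = rtf_from_lfr_frames(forward_time_s=0.5, num_lfr_frames=100, lfr_factor=6)
print(round(rtf, 4))  # 0.0833, i.e. 0.5 s of compute for 6 s of audio
```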
| | |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torchaudio |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | |
| | | ): |
| | | |
| | | # 3. Build data-iterator |
| | | if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": |
| | | raw_inputs = _load_bytes(data_path_and_name_and_type[0]) |
| | | raw_inputs = torch.tensor(raw_inputs) |
| | | if data_path_and_name_and_type is None and raw_inputs is not None: |
| | | if isinstance(raw_inputs, np.ndarray): |
| | | raw_inputs = torch.tensor(raw_inputs) |
| | | is_final = False |
| | | if param_dict is not None and "cache" in param_dict: |
| | | cache = param_dict["cache"] |
| | | if param_dict is not None and "is_final" in param_dict: |
| | | is_final = param_dict["is_final"] |
| | | |
| | | if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": |
| | | raw_inputs = _load_bytes(data_path_and_name_and_type[0]) |
| | | raw_inputs = torch.tensor(raw_inputs) |
| | | if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound": |
| | | raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] |
| | | is_final = True |
| | | if data_path_and_name_and_type is None and raw_inputs is not None: |
| | | if isinstance(raw_inputs, np.ndarray): |
| | | raw_inputs = torch.tensor(raw_inputs) |
| | | # 7. Start for-loop |
| | | # FIXME(kamo): The output format should be discussed about |
| | | asr_result_list = [] |
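The hunk above reads two optional keys from `param_dict`. `"cache"` carries state between calls, and `"is_final"` marks the last chunk. A whole-file `"sound"` input is always treated as final. A minimal sketch of that handling follows. The empty-dict default for `cache` is an assumption, because the real default is set outside this hunk. `read_stream_params` is a hypothetical helper, not FunASR API.

```python
from typing import Any, Dict, Optional, Tuple


def read_stream_params(param_dict: Optional[Dict[str, Any]],
                       input_type: Optional[str] = None) -> Tuple[Dict[str, Any], bool]:
    """Return (cache, is_final) following the param_dict handling shown above.

    - "cache": state shared across calls (defaults to {} here, an assumption)
    - "is_final": True only for the last chunk (defaults to False)
    - input_type == "sound" (a whole file read with torchaudio) forces is_final
    """
    params = param_dict or {}
    cache = params.get("cache", {})
    is_final = bool(params.get("is_final", False))
    if input_type == "sound":
        is_final = True
    return cache, is_final


# Streaming usage: the same cache object is passed back on every call
param_dict = {"cache": {}, "is_final": False}
cache, is_final = read_stream_params(param_dict)
cache["seen_chunks"] = 1  # state written by one call is visible to the next
```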
| | |
| | | ibest_writer["token"][key] = " ".join(token) |
| | | ibest_writer["token_int"][key] = " ".join(map(str, token_int)) |
| | | ibest_writer["vad"][key] = "{}".format(vadsegments) |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | ibest_writer["text_with_punc"][key] = text_postprocessed_punc |
| | | if time_stamp_postprocessed is not None: |
| | | ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed) |
| | |
| | | ibest_writer["token"][key] = " ".join(token) |
| | | ibest_writer["token_int"][key] = " ".join(map(str, token_int)) |
| | | ibest_writer["vad"][key] = "{}".format(vadsegments) |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | ibest_writer["text_with_punc"][key] = text_postprocessed_punc |
| | | if time_stamp_postprocessed is not None: |
| | | ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed) |
| | |
| | | ibest_writer["rtf"][key] = rtf_cur |
| | | |
| | | if text is not None: |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | # asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | |
| | | logging.info("decoding, utt: {}, predictions: {}".format(key, text)) |
| | | rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) |
| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | | |
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | return asr_result_list |
| | | |
| | | return _forward |
| | |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | if text is not None: |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | return asr_result_list |
| | | |
| | | return _forward |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.punctuation.text_preprocessor import split_to_mini_sentence |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | |
| | | |
| | | class Text2Punc: |
| | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.punctuation.text_preprocessor import split_to_mini_sentence |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | |
| | | |
| | | class Text2Punc: |
| | |
| | | data = { |
| | | "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0), |
| | | "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')), |
| | | "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')), |
| | | "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')), |
| | | } |
| | | data = to_device(data, self.device) |
| | | y, _ = self.wrapped_model(**data) |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import torch |
| | | torch.set_num_threads(1) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | # NOTE(kamo): SoundScpReader doesn't support pipe-fashion |
| | | # like Kaldi e.g. "cat a.wav |". |
| | | # NOTE(kamo): The audio signal is normalized to [-1,1] range. |
| | | loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False) |
| | | loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate=dest_sample_rate) |
| | | |
| | | # SoundScpReader.__getitem__() returns Tuple[int, ndarray], |
| | | # but ndarray is desired, so Adapter class is inserted here |
| | |
| | | length = len(text) |
| | | for i in range(length): |
| | | x = text[i] |
| | | if i == length-1 and "punc" in data and text[i].startswith("vad:"): |
| | | vad = x[-1][4:] |
| | | if i == length-1 and "punc" in data and x.startswith("vad:"): |
| | | vad = x[4:] |
| | | if len(vad) == 0: |
| | | vad = -1 |
| | | else: |
| | |
| | | ) -> Dict[str, np.ndarray]: |
| | | for i in range(self.num_tokenizer): |
| | | text_name = self.text_name[i] |
| | | #import pdb; pdb.set_trace() |
| | | if text_name in data and self.tokenizer[i] is not None: |
| | | text = data[text_name] |
| | | text = self.text_cleaner(text) |
| | |
| | | data[self.vad_name] = np.array([vad], dtype=np.int64) |
| | | text_ints = self.token_id_converter[i].tokens2ids(tokens) |
| | | data[text_name] = np.array(text_ints, dtype=np.int64) |
| | | return data |
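The preprocessor hunk above parses a trailing `vad:<n>` token when `"punc"` is in the data, and maps an empty value to `-1`. The new line `vad = x[4:]` slices the token itself. The old line `vad = x[-1][4:]` took the token's last character first, and slicing a one-character string from index 4 is always empty, so the old code always produced `-1`. The sketch below assumes that a non-empty value is converted with `int()`, because that `else:` branch is cut off in the diff. It also assumes the token is dropped from the word list. `split_vad_field` is hypothetical.

```python
from typing import List, Tuple


def split_vad_field(tokens: List[str]) -> Tuple[List[str], int]:
    """Strip a trailing "vad:<n>" token and return (remaining_tokens, vad).

    An empty value maps to -1, as in the hunk above. Using int() for a non-empty
    value and returning -1 when no vad token exists are assumptions.
    """
    if tokens and tokens[-1].startswith("vad:"):
        value = tokens[-1][4:]
        vad = -1 if len(value) == 0 else int(value)
        return tokens[:-1], vad
    return tokens, -1


print(split_vad_field(["hello", "world", "vad:3"]))  # (['hello', 'world'], 3)
```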
| | | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20): |
| | | assert word_limit > 1 |
| | | if len(words) <= word_limit: |
| | | return [words] |
| | | sentences = [] |
| | | length = len(words) |
| | | sentence_len = length // word_limit |
| | | for i in range(sentence_len): |
| | | sentences.append(words[i * word_limit:(i + 1) * word_limit]) |
| | | if length % word_limit > 0: |
| | | sentences.append(words[sentence_len * word_limit:]) |
| | | return sentences |
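`split_to_mini_sentence` cuts a word list into consecutive chunks of `word_limit` words and puts any remainder in a final, shorter chunk. The function is copied below so its behavior can be checked without FunASR installed. Note two edge cases. A length that is an exact multiple of `word_limit` produces no empty trailing chunk. An empty input returns a single empty mini-sentence.

```python
def split_to_mini_sentence(words: list, word_limit: int = 20):
    # Copy of the function above: fixed-size chunks of word_limit, remainder last.
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences


words = [f"w{i}" for i in range(45)]
print([len(s) for s in split_to_mini_sentence(words)])       # [20, 20, 5]
print([len(s) for s in split_to_mini_sentence(words[:40])])  # [20, 20]
print(split_to_mini_sentence([]))                            # [[]]
```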
| | |
| | | |
| | | ## Install modelscope and funasr |
| | | |
| | | The installation is the same as [funasr](../../README.md) |
| | | The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation) |
| | | |
| | | ## Export model |
| | | `Tips`: torch>=1.11.0 |
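The `torch>=1.11.0` requirement can be checked before exporting. The sketch below is simplified: it ignores local build tags such as `+cu117` and pre-release suffixes, so it is not a full PEP 440 comparison (`packaging.version.Version` is the usual tool for that). `torch_version_at_least` is a hypothetical helper. In practice it would be called as `torch_version_at_least(torch.__version__)`.

```python
import re


def torch_version_at_least(version: str, minimum=(1, 11, 0)) -> bool:
    """Compare a torch.__version__-style string against (major, minor, patch)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts >= tuple(minimum)


print(torch_version_at_least("1.13.1+cu117"))  # True
print(torch_version_at_least("1.10.2"))        # False
```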
| | |
| | | |
| | | def export(self, |
| | | tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | mode: str = 'paraformer', |
| | | mode: str = None, |
| | | ): |
| | | |
| | | model_dir = tag_name |
| | | if model_dir.startswith('damo/'): |
| | | if model_dir.startswith('damo'): |
| | | from modelscope.hub.snapshot_download import snapshot_download |
| | | model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) |
| | | asr_train_config = os.path.join(model_dir, 'config.yaml') |
| | | asr_model_file = os.path.join(model_dir, 'model.pb') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | |
| | | if mode is None: |
| | | import json |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | with open(json_file, 'r') as f: |
| | | config_data = json.load(f) |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if config_data['task'] == "punctuation": |
| | | mode = config_data['model']['punc_model_config']['mode'] |
| | | else: |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if mode.startswith('paraformer'): |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | elif mode.startswith('uniasr'): |
| | | from funasr.tasks.asr import ASRTaskUniASR as ASRTask |
| | | config = os.path.join(model_dir, 'config.yaml') |
| | | model_file = os.path.join(model_dir, 'model.pb') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | config, model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self.frontend = model.frontend |
| | | elif mode.startswith('offline'): |
| | | from funasr.tasks.vad import VADTask |
| | | config = os.path.join(model_dir, 'vad.yaml') |
| | | model_file = os.path.join(model_dir, 'vad.pb') |
| | | cmvn_file = os.path.join(model_dir, 'vad.mvn') |
| | | |
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self.frontend = model.frontend |
| | | model, vad_infer_args = VADTask.build_model_from_file( |
| | | config, model_file, cmvn_file=cmvn_file, device='cpu' |
| | | ) |
| | | self.export_config["feats_dim"] = 400 |
| | | self.frontend = model.frontend |
| | | elif mode.startswith('punc'): |
| | | # Covers both 'punc' and 'punc_VadRealtime': the latter also starts with 'punc', |
| | | # so a separate `elif mode.startswith('punc_VadRealtime')` branch after this one would never run. |
| | | from funasr.tasks.punctuation import PunctuationTask as PUNCTask |
| | | punc_train_config = os.path.join(model_dir, 'config.yaml') |
| | | punc_model_file = os.path.join(model_dir, 'punc.pb') |
| | | model, punc_train_args = PUNCTask.build_model_from_file( |
| | | punc_train_config, punc_model_file, 'cpu' |
| | | ) |
| | | self._export(model, tag_name) |
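`export()` reads `mode` from `configuration.json`. For the punctuation task it reads `model.punc_model_config.mode`; otherwise it reads `model.model_config.mode`. It then picks a branch by string prefix. Prefix matching is order-sensitive: `'punc_VadRealtime'.startswith('punc')` is `True`, so a more specific prefix must be tested before a shorter one, or its branch is unreachable. The sketch below shows the lookup as an ordered table. The kind labels and both function names are hypothetical.

```python
import json
from typing import Optional

# Most specific prefixes first: 'punc_VadRealtime' also starts with 'punc'.
_EXPORT_KINDS = [
    ("paraformer", "asr"),
    ("uniasr", "asr"),
    ("offline", "vad"),
    ("punc_VadRealtime", "punc_vad_realtime"),
    ("punc", "punc"),
]


def resolve_mode(config_data: dict) -> str:
    """Read the mode from a parsed configuration.json, as export() does."""
    if config_data["task"] == "punctuation":
        return config_data["model"]["punc_model_config"]["mode"]
    return config_data["model"]["model_config"]["mode"]


def export_kind(mode: str) -> Optional[str]:
    """Map a mode string to an export branch by its first matching prefix."""
    for prefix, kind in _EXPORT_KINDS:
        if mode.startswith(prefix):
            return kind
    return None


config = json.loads(
    '{"task": "punctuation", "model": {"punc_model_config": {"mode": "punc_VadRealtime"}}}'
)
print(export_kind(resolve_mode(config)))  # punc_vad_realtime
```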
| | | |
| | | |
| New file |
| | |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export |
| | | from funasr.models.encoder.sanm_encoder import SANMVadEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export |
| | | |
| | | class CT_Transformer(nn.Module): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | model_name='punc_model', |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | onnx = False |
| | | if "onnx" in kwargs: |
| | | onnx = kwargs["onnx"] |
| | | self.embed = model.embed |
| | | self.decoder = model.decoder |
| | | # self.model = model |
| | | self.feats_dim = self.embed.embedding_dim |
| | | self.num_embeddings = self.embed.num_embeddings |
| | | self.model_name = model_name |
| | | |
| | | if isinstance(model.encoder, SANMEncoder): |
| | | self.encoder = SANMEncoder_export(model.encoder, onnx=onnx) |
| | | else: |
| | | raise NotImplementedError("Only the SANM encoder (SANMEncoder) is supported.") |
| | | |
| | | def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> torch.Tensor: |
| | | """Compute punctuation logits for a batch of token-id sequences. |
| | | |
| | | Args: |
| | | inputs (torch.Tensor): Input token ids. (batch, len) |
| | | text_lengths (torch.Tensor): Valid length of each sequence. (batch,) |
| | | |
| | | Returns: |
| | | torch.Tensor: Punctuation logits. (batch, len, num_punc) |
| | | """ |
| | | x = self.embed(inputs) |
| | | # mask = self._target_mask(input) |
| | | h, _ = self.encoder(x, text_lengths) |
| | | y = self.decoder(h) |
| | | return y |
| | | |
| | | def get_dummy_inputs(self): |
| | | length = 120 |
| | | text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)) |
| | | text_lengths = torch.tensor([length-20, length], dtype=torch.int32) |
| | | return (text_indexes, text_lengths) |
| | | |
| | | def get_input_names(self): |
| | | return ['inputs', 'text_lengths'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'inputs': { |
| | | 0: 'batch_size', |
| | | 1: 'feats_length' |
| | | }, |
| | | 'text_lengths': { |
| | | 0: 'batch_size', |
| | | }, |
| | | 'logits': { |
| | | 0: 'batch_size', |
| | | 1: 'logits_length' |
| | | }, |
| | | } |
| | | |
| | | |
| | | class CT_Transformer_VadRealtime(nn.Module): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | model_name='punc_model', |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | onnx = False |
| | | if "onnx" in kwargs: |
| | | onnx = kwargs["onnx"] |
| | | |
| | | self.embed = model.embed |
| | | if isinstance(model.encoder, SANMVadEncoder): |
| | | self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx) |
| | | else: |
| | | raise NotImplementedError("Only the SANM VAD encoder (SANMVadEncoder) is supported.") |
| | | self.decoder = model.decoder |
| | | self.model_name = model_name |
| | | |
| | | |
| | | |
| | | def forward(self, inputs: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | vad_indexes: torch.Tensor, |
| | | sub_masks: torch.Tensor, |
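| | | ) -> torch.Tensor: |
| | | """Compute punctuation logits for a token-id sequence with VAD-aware masking. |
| | | |
| | | Args: |
| | | inputs (torch.Tensor): Input token ids. (batch, len) |
| | | text_lengths (torch.Tensor): Valid length of each sequence. (batch,) |
| | | vad_indexes (torch.Tensor): VAD mask, exported as `vad_masks`. (1, 1, len, len) |
| | | sub_masks (torch.Tensor): Lower-triangular mask. (1, 1, len, len) |
| | | |
| | | Returns: |
| | | torch.Tensor: Punctuation logits. (batch, len, num_punc) |
| | | """ |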
| | | x = self.embed(inputs) |
| | | # mask = self._target_mask(input) |
| | | h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks) |
| | | y = self.decoder(h) |
| | | return y |
| | | |
| | | def with_vad(self): |
| | | return True |
| | | |
| | | def get_dummy_inputs(self): |
| | | length = 120 |
| | | text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)) |
| | | text_lengths = torch.tensor([length], dtype=torch.int32) |
| | | vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :] |
| | | sub_masks = torch.ones(length, length, dtype=torch.float32) |
| | | sub_masks = torch.tril(sub_masks).type(torch.float32) |
| | | return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) |
| | | |
| | | def get_input_names(self): |
| | | return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'inputs': { |
| | | 1: 'feats_length' |
| | | }, |
| | | 'vad_masks': { |
| | | 2: 'feats_length1', |
| | | 3: 'feats_length2' |
| | | }, |
| | | 'sub_masks': { |
| | | 2: 'feats_length1', |
| | | 3: 'feats_length2' |
| | | }, |
| | | 'logits': { |
| | | 1: 'logits_length' |
| | | }, |
| | | } |
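`CT_Transformer_VadRealtime.get_dummy_inputs()` builds two square masks of shape `(1, 1, length, length)`: an all-ones `vad_mask` and a lower-triangular `sub_masks`. Axes 2 and 3 of both are declared dynamic, so the exported graph accepts any square length. The NumPy sketch below mirrors the `torch.ones` / `torch.tril` calls so the shapes and triangle pattern can be inspected. Reading `sub_mask` as "row i may attend to positions 0..i" is an interpretation, not something the code states. `make_dummy_masks` is hypothetical.

```python
import numpy as np


def make_dummy_masks(length: int):
    """Return (vad_mask, sub_mask), both float32 with shape (1, 1, length, length)."""
    vad_mask = np.ones((length, length), dtype=np.float32)[None, None, :, :]
    # Lower triangle (including the diagonal) is 1, upper triangle is 0
    sub_mask = np.tril(np.ones((length, length), dtype=np.float32))[None, None, :, :]
    return vad_mask, sub_mask


vad_mask, sub_mask = make_dummy_masks(5)
print(sub_mask[0, 0])
```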
| | |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | |
| | | from funasr.models.e2e_vad import E2EVadModel |
| | | from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export |
| | | |
| | | def get_model(model, export_config=None): |
| | | if isinstance(model, BiCifParaformer): |
| | | return BiCifParaformer_export(model, **export_config) |
| | | elif isinstance(model, Paraformer): |
| | | return Paraformer_export(model, **export_config) |
| | | elif isinstance(model, E2EVadModel): |
| | | return E2EVadModel_export(model, **export_config) |
| | | elif isinstance(model, PunctuationModel): |
| | | if isinstance(model.punc_model, TargetDelayTransformer): |
| | | return CT_Transformer_export(model.punc_model, **export_config) |
| | | elif isinstance(model.punc_model, VadRealtimeTransformer): |
| | | return CT_Transformer_VadRealtime_export(model.punc_model, **export_config) |
| | | else: |
| | | raise NotImplementedError("FunASR does not support the given model type currently.") |
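`get_model` dispatches with `isinstance` and tests `BiCifParaformer` before `Paraformer`. That order matters if, as the check order suggests, `BiCifParaformer` subclasses `Paraformer`: a subclass instance also passes the parent's check, so testing the parent first would send it to the wrong exporter. The fallback must raise an exception object, because in Python 3 `raise "message"` itself fails with `TypeError: exceptions must derive from BaseException`. The classes below are empty stand-ins used only to illustrate this, not the FunASR models.

```python
class Paraformer:  # stand-in, illustration only
    pass


class BiCifParaformer(Paraformer):  # stand-in subclass
    pass


def pick_exporter(model) -> str:
    # Subclass first: a BiCifParaformer also satisfies isinstance(model, Paraformer).
    if isinstance(model, BiCifParaformer):
        return "BiCifParaformer_export"
    elif isinstance(model, Paraformer):
        return "Paraformer_export"
    raise NotImplementedError(f"unsupported model type: {type(model).__name__}")


print(pick_exporter(BiCifParaformer()))  # BiCifParaformer_export
```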
| | |
| | | |
| | | class Paraformer(nn.Module): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| | |
| | | |
| | | class BiCifParaformer(nn.Module): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| New file |
| | |
| | | from enum import Enum |
| | | from typing import List, Tuple, Dict, Any |
| | | |
| | | import torch |
| | | from torch import nn |
| | | import math |
| | | |
| | | from funasr.models.encoder.fsmn_encoder import FSMN |
| | | from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, model, |
| | | max_seq_len=512, |
| | | feats_dim=400, |
| | | model_name='model', |
| | | **kwargs,): |
| | | super(E2EVadModel, self).__init__() |
| | | self.feats_dim = feats_dim |
| | | self.max_seq_len = max_seq_len |
| | | self.model_name = model_name |
| | | if isinstance(model.encoder, FSMN): |
| | | self.encoder = FSMN_export(model.encoder) |
| | | else: |
| | | raise NotImplementedError("Unsupported encoder: only FSMN is supported.") |
| | | |
| | | |
| | | def forward(self, feats: torch.Tensor, *args, ): |
| | | |
| | | scores, out_caches = self.encoder(feats, *args) |
| | | return scores, out_caches |
| | | |
| | | def get_dummy_inputs(self, frame=30): |
| | | speech = torch.randn(1, frame, self.feats_dim) |
| | | in_cache0 = torch.randn(1, 128, 19, 1) |
| | | in_cache1 = torch.randn(1, 128, 19, 1) |
| | | in_cache2 = torch.randn(1, 128, 19, 1) |
| | | in_cache3 = torch.randn(1, 128, 19, 1) |
| | | |
| | | return (speech, in_cache0, in_cache1, in_cache2, in_cache3) |
| | | |
| | | # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"): |
| | | # import numpy as np |
| | | # fbank = np.loadtxt(txt_file) |
| | | # fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32) |
| | | # speech = torch.from_numpy(fbank[None, :, :].astype(np.float32)) |
| | | # speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32)) |
| | | # return (speech, speech_lengths) |
| | | |
| | | def get_input_names(self): |
| | | return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'speech': { |
| | | 1: 'feats_length' |
| | | }, |
| | | } |
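The exported VAD model is stateful through explicit tensors. It takes `speech` plus `in_cache0`..`in_cache3` and returns `logits` plus `out_cache0`..`out_cache3`. Each call's output caches must be fed back as the next call's input caches. The sketch below drives any `step(feed) -> outputs` callable this way. With onnxruntime, `step` could be `lambda feed: sess.run(None, feed)` on an `InferenceSession`. Several details are assumptions taken from `get_dummy_inputs()`: starting from zero caches, the cache shape `(1, 128, 19, 1)`, and batch size 1. The fake `step` in the usage only checks that the caches are threaded through.

```python
from typing import Callable, Dict, List, Sequence, Tuple

import numpy as np

CACHE_SHAPE = (1, 128, 19, 1)  # as in E2EVadModel.get_dummy_inputs()


def stream_vad(step: Callable[[Dict[str, np.ndarray]], Sequence[np.ndarray]],
               feats: np.ndarray, chunk_frames: int,
               num_caches: int = 4) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Run an exported VAD model chunk by chunk; return (per-chunk logits, final caches)."""
    caches = [np.zeros(CACHE_SHAPE, dtype=np.float32) for _ in range(num_caches)]
    all_logits = []
    for start in range(0, feats.shape[1], chunk_frames):
        feed = {"speech": feats[:, start:start + chunk_frames]}
        feed.update({f"in_cache{i}": c for i, c in enumerate(caches)})
        logits, *caches = step(feed)  # out_cache_i becomes the next in_cache_i
        all_logits.append(logits)
    return all_logits, caches


def fake_step(feed):
    # Stand-in for a real model: sums features, increments every cache by 1
    return [feed["speech"].sum(axis=-1)] + [feed[f"in_cache{i}"] + 1 for i in range(4)]


feats = np.ones((1, 30, 400), dtype=np.float32)  # (batch, frames, feats_dim)
logits, caches = stream_vad(fake_step, feats, chunk_frames=10)
```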
| New file |
| | |
| | | from typing import Tuple, Dict |
| | | import copy |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.models.encoder.fsmn_encoder import BasicBlock |
| | | |
| | | class LinearTransform(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(LinearTransform, self).__init__() |
| | | self.input_dim = input_dim |
| | | self.output_dim = output_dim |
| | | self.linear = nn.Linear(input_dim, output_dim, bias=False) |
| | | |
| | | def forward(self, input): |
| | | output = self.linear(input) |
| | | |
| | | return output |
| | | |
| | | |
| | | class AffineTransform(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(AffineTransform, self).__init__() |
| | | self.input_dim = input_dim |
| | | self.output_dim = output_dim |
| | | self.linear = nn.Linear(input_dim, output_dim) |
| | | |
| | | def forward(self, input): |
| | | output = self.linear(input) |
| | | |
| | | return output |
| | | |
| | | |
| | | class RectifiedLinear(nn.Module): |
| | | |
| | | def __init__(self, input_dim, output_dim): |
| | | super(RectifiedLinear, self).__init__() |
| | | self.dim = input_dim |
| | | self.relu = nn.ReLU() |
| | | self.dropout = nn.Dropout(0.1) |
| | | |
| | | def forward(self, input): |
| | | out = self.relu(input) |
| | | return out |
| | | |
| | | |
| | | class FSMNBlock(nn.Module): |
| | | |
| | | def __init__( |
| | | self, |
| | | input_dim: int, |
| | | output_dim: int, |
| | | lorder=None, |
| | | rorder=None, |
| | | lstride=1, |
| | | rstride=1, |
| | | ): |
| | | super(FSMNBlock, self).__init__() |
| | | |
| | | self.dim = input_dim |
| | | |
| | | if lorder is None: |
| | | return |
| | | |
| | | self.lorder = lorder |
| | | self.rorder = rorder |
| | | self.lstride = lstride |
| | | self.rstride = rstride |
| | | |
| | | self.conv_left = nn.Conv2d( |
| | | self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False) |
| | | |
| | | if self.rorder > 0: |
| | | self.conv_right = nn.Conv2d( |
| | | self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) |
| | | else: |
| | | self.conv_right = None |
| | | |
| | | def forward(self, input: torch.Tensor, cache: torch.Tensor): |
| | | x = torch.unsqueeze(input, 1) |
| | | x_per = x.permute(0, 3, 2, 1) # B D T C |
| | | |
| | | cache = cache.to(x_per.device) |
| | | y_left = torch.cat((cache, x_per), dim=2) |
| | | cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] |
| | | y_left = self.conv_left(y_left) |
| | | out = x_per + y_left |
| | | |
| | | if self.conv_right is not None: |
| | | # maybe need to check |
| | | y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) |
| | | y_right = y_right[:, :, self.rstride:, :] |
| | | y_right = self.conv_right(y_right) |
| | | out += y_right |
| | | |
| | | out_per = out.permute(0, 3, 2, 1) |
| | | output = out_per.squeeze(1) |
| | | |
| | | return output, cache |
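`FSMNBlock.forward` prepends the cached frames to the new chunk and applies the depthwise `conv_left` (kernel `lorder`, dilation `lstride`). It then keeps the last `(lorder - 1) * lstride` frames as the next cache. Starting from a zero cache is equivalent to the zero left padding that `DFSMN` applies with `F.pad`. As a result, chunked processing should reproduce whole-sequence processing. The NumPy sketch below checks this. It uses a `(T, D)` layout instead of the `(B, D, T, 1)` tensors in the code, and it omits the look-ahead `conv_right` branch. `fsmn_left_memory` is a hypothetical name.

```python
import numpy as np


def fsmn_left_memory(x, taps, cache, lstride=1):
    """x: (T, D) new frames; taps: (lorder, D) depthwise weights;
    cache: ((lorder - 1) * lstride, D) frames kept from the previous chunk.
    Returns (x + memory, new_cache)."""
    lorder = taps.shape[0]
    cache_len = (lorder - 1) * lstride
    ctx = np.concatenate([cache, x], axis=0)  # (cache_len + T, D)
    T = x.shape[0]
    memory = np.zeros_like(x)
    for k in range(lorder):
        # Cross-correlation with dilation lstride, as nn.Conv2d computes it
        memory += taps[k] * ctx[k * lstride:k * lstride + T]
    new_cache = ctx[ctx.shape[0] - cache_len:]  # last cache_len frames
    return x + memory, new_cache


rng = np.random.default_rng(0)
lorder, lstride, D = 4, 2, 3
taps = rng.standard_normal((lorder, D))
x = rng.standard_normal((12, D))
zero_cache = np.zeros(((lorder - 1) * lstride, D))

full, _ = fsmn_left_memory(x, taps, zero_cache, lstride)  # whole sequence at once
cache, outs = zero_cache, []
for chunk in (x[:5], x[5:9], x[9:]):  # chunks may be shorter than the cache
    y, cache = fsmn_left_memory(chunk, taps, cache, lstride)
    outs.append(y)
chunked = np.concatenate(outs, axis=0)
```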
| | | |
| | | |
| | | class BasicBlock_export(nn.Module): |
| | | def __init__(self, |
| | | model, |
| | | ): |
| | | super(BasicBlock_export, self).__init__() |
| | | self.linear = model.linear |
| | | self.fsmn_block = model.fsmn_block |
| | | self.affine = model.affine |
| | | self.relu = model.relu |
| | | |
| | | def forward(self, input: torch.Tensor, in_cache: torch.Tensor): |
| | | x = self.linear(input) # B T D |
| | | # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) |
| | | # if cache_layer_name not in in_cache: |
| | | # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) |
| | | x, out_cache = self.fsmn_block(x, in_cache) |
| | | x = self.affine(x) |
| | | x = self.relu(x) |
| | | return x, out_cache |
| | | |
| | | |
| | | # class FsmnStack(nn.Sequential): |
| | | # def __init__(self, *args): |
| | | # super(FsmnStack, self).__init__(*args) |
| | | # |
| | | # def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): |
| | | # x = input |
| | | # for module in self._modules.values(): |
| | | # x = module(x, in_cache) |
| | | # return x |
| | | |
| | | |
| | | ''' |
| | | FSMN net for keyword spotting |
| | | input_dim: input dimension |
| | | linear_dim: fsmn input dimension |
| | | proj_dim: fsmn projection dimension |
| | | lorder: fsmn left order |
| | | rorder: fsmn right order |
| | | num_syn: output dimension |
| | | fsmn_layers: no. of sequential fsmn layers |
| | | ''' |
| | | |
| | | |
| | | class FSMN(nn.Module): |
| | | def __init__( |
| | | self, model, |
| | | ): |
| | | super(FSMN, self).__init__() |
| | | |
| | | # self.input_dim = input_dim |
| | | # self.input_affine_dim = input_affine_dim |
| | | # self.fsmn_layers = fsmn_layers |
| | | # self.linear_dim = linear_dim |
| | | # self.proj_dim = proj_dim |
| | | # self.output_affine_dim = output_affine_dim |
| | | # self.output_dim = output_dim |
| | | # |
| | | # self.in_linear1 = AffineTransform(input_dim, input_affine_dim) |
| | | # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) |
| | | # self.relu = RectifiedLinear(linear_dim, linear_dim) |
| | | # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in |
| | | # range(fsmn_layers)]) |
| | | # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) |
| | | # self.out_linear2 = AffineTransform(output_affine_dim, output_dim) |
| | | # self.softmax = nn.Softmax(dim=-1) |
| | | self.in_linear1 = model.in_linear1 |
| | | self.in_linear2 = model.in_linear2 |
| | | self.relu = model.relu |
| | | # self.fsmn = model.fsmn |
| | | self.out_linear1 = model.out_linear1 |
| | | self.out_linear2 = model.out_linear2 |
| | | self.softmax = model.softmax |
| | | self.fsmn = model.fsmn |
| | | for i, d in enumerate(model.fsmn): |
| | | if isinstance(d, BasicBlock): |
| | | self.fsmn[i] = BasicBlock_export(d) |
| | | |
| | | def fuse_modules(self): |
| | | pass |
| | | |
| | | def forward( |
| | | self, |
| | | input: torch.Tensor, |
| | | *args, |
| | | ): |
| | | """ |
| | | Args: |
| | | input (torch.Tensor): Input tensor (B, T, D) |
| | | *args: one cache tensor per FSMN layer, in layer order, each of shape |
| | | (B, proj_dim, (lorder - 1) * lstride, 1). Zeros start a new stream; the |
| | | updated caches returned with the output are passed back for the next chunk. |
| | | """ |
| | | |
| | | x = self.in_linear1(input) |
| | | x = self.in_linear2(x) |
| | | x = self.relu(x) |
| | | # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn |
| | | out_caches = list() |
| | | for i, d in enumerate(self.fsmn): |
| | | in_cache = args[i] |
| | | x, out_cache = d(x, in_cache) |
| | | out_caches.append(out_cache) |
| | | x = self.out_linear1(x) |
| | | x = self.out_linear2(x) |
| | | x = self.softmax(x) |
| | | |
| | | return x, out_caches |
| | | |
| | | |
| | | ''' |
| | | one deep fsmn layer |
| | | dimproj: projection dimension, input and output dimension of memory blocks |
| | | dimlinear: dimension of mapping layer |
| | | lorder: left order |
| | | rorder: right order |
| | | lstride: left stride |
| | | rstride: right stride |
| | | ''' |
| | | |
| | | |
| | | class DFSMN(nn.Module): |
| | | |
| | | def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1): |
| | | super(DFSMN, self).__init__() |
| | | |
| | | self.lorder = lorder |
| | | self.rorder = rorder |
| | | self.lstride = lstride |
| | | self.rstride = rstride |
| | | |
| | | self.expand = AffineTransform(dimproj, dimlinear) |
| | | self.shrink = LinearTransform(dimlinear, dimproj) |
| | | |
| | | self.conv_left = nn.Conv2d( |
| | | dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False) |
| | | |
| | | if rorder > 0: |
| | | self.conv_right = nn.Conv2d( |
| | | dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False) |
| | | else: |
| | | self.conv_right = None |
| | | |
| | | def forward(self, input): |
| | | f1 = F.relu(self.expand(input)) |
| | | p1 = self.shrink(f1) |
| | | |
| | | x = torch.unsqueeze(p1, 1) |
| | | x_per = x.permute(0, 3, 2, 1) |
| | | |
| | | y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) |
| | | |
| | | if self.conv_right is not None: |
| | | y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) |
| | | y_right = y_right[:, :, self.rstride:, :] |
| | | out = x_per + self.conv_left(y_left) + self.conv_right(y_right) |
| | | else: |
| | | out = x_per + self.conv_left(y_left) |
| | | |
| | | out1 = out.permute(0, 3, 2, 1) |
| | | output = input + out1.squeeze(1) |
| | | |
| | | return output |
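`DFSMN.forward` adds a left (past) and an optional right (future) memory to the projected frames `p`: `out[t] = p[t] + sum_k wl[k] * p[t - (lorder - 1 - k) * lstride] + sum_k wr[k] * p[t + (k + 1) * rstride]`, with zeros outside the sequence. The NumPy sketch below reproduces the padding and slicing used in `forward` for a `(T, D)` input, with per-channel tap arrays standing in for the depthwise `conv_left` / `conv_right` weights. It omits the `expand`/`shrink` projections and the outer residual, and the function names are hypothetical.

```python
import numpy as np


def shift_filter(xp: np.ndarray, w: np.ndarray, stride: int, T: int) -> np.ndarray:
    # Valid cross-correlation along time with dilation `stride`, one tap set per channel
    out = np.zeros((T, xp.shape[1]), dtype=xp.dtype)
    for k in range(w.shape[0]):
        out += w[k] * xp[k * stride:k * stride + T]
    return out


def fsmn_memory(p, wl, wr=None, lstride=1, rstride=1):
    """Memory block on projected frames p of shape (T, D).

    Left: (lorder - 1) * lstride zeros prepended, as for conv_left.
    Right: rorder * rstride zeros appended, then the first rstride frames dropped,
    so tap k of conv_right sees frame t + (k + 1) * rstride.
    """
    T, D = p.shape
    lorder = wl.shape[0]
    zeros_l = np.zeros(((lorder - 1) * lstride, D), dtype=p.dtype)
    out = p + shift_filter(np.concatenate([zeros_l, p]), wl, lstride, T)
    if wr is not None and wr.shape[0] > 0:
        rorder = wr.shape[0]
        zeros_r = np.zeros((rorder * rstride, D), dtype=p.dtype)
        y_right = np.concatenate([p, zeros_r])[rstride:]
        out = out + shift_filter(y_right, wr, rstride, T)
    return out


p = np.arange(1, 6, dtype=np.float64)[:, None]  # T = 5, D = 1
# lorder = 2, rorder = 1: out[t] = p[t] + (p[t - 1] + p[t]) + p[t + 1]
out = fsmn_memory(p, wl=np.ones((2, 1)), wr=np.ones((1, 1)))
# lorder = 2, lstride = 2, only the oldest tap is non-zero: out[t] = p[t] + p[t - 2]
out_strided = fsmn_memory(p, wl=np.array([[1.0], [0.0]]), lstride=2)
```

Dropping the first `rstride` frames is what keeps the right memory strictly in the future: without that slice, the first right tap would read the current frame again.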
| | | |
| | | |
| | | ''' |
| | | build stacked dfsmn layers |
| | | ''' |
| | | |
| | | |
| | | def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6): |
| | | repeats = [ |
| | | nn.Sequential( |
| | | DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) |
| | | for i in range(fsmn_layers) |
| | | ] |
| | | |
| | | return nn.Sequential(*repeats) |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) |
| | | print(fsmn) |
| | | |
| | | num_params = sum(p.numel() for p in fsmn.parameters()) |
| | | print('the number of model params: {}'.format(num_params)) |
| | | x = torch.zeros(128, 200, 400) # batch-size * time * dim |
| | | y, _ = fsmn(x) # batch-size * time * dim |
| | | print('input shape: {}'.format(x.shape)) |
| | | print('output shape: {}'.format(y.shape)) |
| | | |
| | | print(fsmn.to_kaldi_net()) |
| | |
| | | from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward |
| | | from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export |
| | | |
| | | |
| | | class SANMEncoder(nn.Module): |
| | | def __init__( |
| | | self, |
| | |
| | | } |
| | | |
| | | } |
| | | |
| | | |
| | | class SANMVadEncoder(nn.Module): |
| | | def __init__( |
| | | self, |
| | | model, |
| | | max_seq_len=512, |
| | | feats_dim=560, |
| | | model_name='encoder', |
| | | onnx: bool = True, |
| | | ): |
| | | super().__init__() |
| | | self.embed = model.embed |
| | | self.model = model |
| | | self.feats_dim = feats_dim |
| | | self._output_size = model._output_size |
| | | |
| | | if onnx: |
| | | self.make_pad_mask = MakePadMask(max_seq_len, flip=False) |
| | | else: |
| | | self.make_pad_mask = sequence_mask(max_seq_len, flip=False) |
| | | |
| | | if hasattr(model, 'encoders0'): |
| | | for i, d in enumerate(self.model.encoders0): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) |
| | | if isinstance(d.feed_forward, PositionwiseFeedForward): |
| | | d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) |
| | | self.model.encoders0[i] = EncoderLayerSANM_export(d) |
| | | |
| | | for i, d in enumerate(self.model.encoders): |
| | | if isinstance(d.self_attn, MultiHeadedAttentionSANM): |
| | | d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn) |
| | | if isinstance(d.feed_forward, PositionwiseFeedForward): |
| | | d.feed_forward = PositionwiseFeedForward_export(d.feed_forward) |
| | | self.model.encoders[i] = EncoderLayerSANM_export(d) |
| | | |
| | | self.model_name = model_name |
| | | self.num_heads = model.encoders[0].self_attn.h |
| | | self.hidden_size = model.encoders[0].self_attn.linear_out.out_features |
| | | |
| | | def prepare_mask(self, mask, sub_masks): |
| | | mask_3d_btd = mask[:, :, None] |
| | | mask_4d_bhlt = (1 - sub_masks) * -10000.0 |
| | | |
| | | return mask_3d_btd, mask_4d_bhlt |
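`prepare_mask` turns a 0/1 mask into an additive bias of `0` (keep) or `-10000` (block). How `MultiHeadedAttentionSANM_export` consumes `mask_4d_bhlt` is not shown in this diff; the NumPy sketch below assumes the bias is added to the attention scores before the softmax, which is the usual reason for such a constant. The helper names are hypothetical, and the causal mask is built the same way as `sub_masks` in the punctuation ONNX test.

```python
import numpy as np


def additive_bias(sub_masks: np.ndarray) -> np.ndarray:
    # Same arithmetic as prepare_mask: kept positions (1) -> 0, blocked positions (0) -> -10000
    return (1.0 - sub_masks) * -10000.0


def masked_softmax(scores: np.ndarray, sub_masks: np.ndarray) -> np.ndarray:
    z = scores + additive_bias(sub_masks)
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


T = 4
# Causal (lower-triangular) mask shaped (1, 1, T, T)
sub_masks = np.tril(np.ones((T, T), dtype=np.float32))[None, None, :, :]
scores = np.zeros((1, 1, T, T), dtype=np.float32)  # uniform raw scores
probs = masked_softmax(scores, sub_masks)
```

Because `exp(-10000)` underflows to zero, query `t` spreads its weight evenly over keys `0..t` and gives none to future keys.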
| | | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | vad_masks: torch.Tensor, |
| | | sub_masks: torch.Tensor, |
| | | ): |
| | | speech = speech * self._output_size ** 0.5 |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | vad_masks = self.prepare_mask(mask, vad_masks) |
| | | mask = self.prepare_mask(mask, sub_masks) |
| | | |
| | | if self.embed is None: |
| | | xs_pad = speech |
| | | else: |
| | | xs_pad = self.embed(speech) |
| | | |
| | | encoder_outs = self.model.encoders0(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | # encoder_outs = self.model.encoders(xs_pad, mask) |
| | | for layer_idx, encoder_layer in enumerate(self.model.encoders): |
| | | if layer_idx == len(self.model.encoders) - 1: |
| | | mask = vad_masks |
| | | encoder_outs = encoder_layer(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | | xs_pad = self.model.after_norm(xs_pad) |
| | | |
| | | return xs_pad, speech_lengths |
| | | |
| | | def get_output_size(self): |
| | | return self.model.encoders[0].size |
| | | |
| | | # def get_dummy_inputs(self): |
| | | # feats = torch.randn(1, 100, self.feats_dim) |
| | | # return (feats) |
| | | # |
| | | # def get_input_names(self): |
| | | # return ['feats'] |
| | | # |
| | | # def get_output_names(self): |
| | | # return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] |
| | | # |
| | | # def get_dynamic_axes(self): |
| | | # return { |
| | | # 'feats': { |
| | | # 1: 'feats_length' |
| | | # }, |
| | | # 'encoder_out': { |
| | | # 1: 'enc_out_length' |
| | | # }, |
| | | # 'predictor_weight': { |
| | | # 1: 'pre_out_length' |
| | | # } |
| | | # |
| | | # } |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(text_length): |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)} |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value)) |
| | | _run(_get_feed_dict(10)) |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(text_length): |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), |
| | | 'text_lengths': np.array([text_length,], dtype=np.int32), |
| | | 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | } |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value)) |
| | | _run(_get_feed_dict(10)) |
| New file |
| | |
| | | import onnxruntime |
| | | import numpy as np |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx" |
| | | sess = onnxruntime.InferenceSession(onnx_path) |
| | | input_name = [nd.name for nd in sess.get_inputs()] |
| | | output_name = [nd.name for nd in sess.get_outputs()] |
| | | |
| | | def _get_feed_dict(feats_length): |
| | | |
| | | return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32), |
| | | 'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | 'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32), |
| | | } |
| | | |
| | | def _run(feed_dict): |
| | | output = sess.run(output_name, input_feed=feed_dict) |
| | | for name, value in zip(output_name, output): |
| | | print('{}: {}'.format(name, value.shape)) |
| | | |
| | | _run(_get_feed_dict(100)) |
| | | _run(_get_feed_dict(200)) |
| | |
| | | import torch |
| | | |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | class AbsLM(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract LM class |
| | |
| | | self, input: torch.Tensor, hidden: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | |
| | | class LanguageModel(AbsESPnetModel): |
| | | def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self.lm = lm |
| | | self.sos = 1 |
| | | self.eos = 2 |
| | | |
| | | # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. |
| | | self.ignore_id = ignore_id |
| | | |
| | | def nll( |
| | | self, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | max_length: Optional[int] = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Compute negative log likelihood (nll) |
| | | |
| | | Normally, this function is called in batchify_nll. |
| | | Args: |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | max_length: int, optional; trim and pad every sample to this length when given |
| | | """ |
| | | batch_size = text.size(0) |
| | | # For data parallel |
| | | if max_length is None: |
| | | text = text[:, : text_lengths.max()] |
| | | else: |
| | | text = text[:, :max_length] |
| | | |
| | | # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>' |
| | | # text: (Batch, Length) -> x, t: (Batch, Length + 1) |
| | | x = F.pad(text, [1, 0], "constant", self.sos) |
| | | t = F.pad(text, [0, 1], "constant", self.ignore_id) |
| | | for i, l in enumerate(text_lengths): |
| | | t[i, l] = self.eos |
| | | x_lengths = text_lengths + 1 |
| | | |
| | | # 2. Forward Language model |
| | | # x: (Batch, Length + 1) -> y: (Batch, Length + 1, NVocab) |
| | | y, _ = self.lm(x, None) |
| | | |
| | | # 3. Calc negative log likelihood |
| | | # nll: (BxL,) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") |
| | | # nll: (BxL,) -> (BxL,) |
| | | if max_length is None: |
| | | nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0) |
| | | else: |
| | | nll.masked_fill_( |
| | | make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1), |
| | | 0.0, |
| | | ) |
| | | # nll: (BxL,) -> (B, L) |
| | | nll = nll.view(batch_size, -1) |
| | | return nll, x_lengths |
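Step 1 of `nll` builds shifted input/target pairs. A NumPy sketch of the same padding, using the `sos = 1`, `eos = 2` and `ignore_id = 0` values set in `LanguageModel.__init__` (`make_lm_pair` is a hypothetical helper):

```python
import numpy as np

SOS, EOS, IGNORE_ID = 1, 2, 0  # values assigned in LanguageModel.__init__


def make_lm_pair(text: np.ndarray, text_lengths: np.ndarray):
    """Return (x, t, x_lengths) like step 1 of LanguageModel.nll.

    x: '<sos> w1 w2 ...' (input), t: 'w1 w2 ... <eos>' (target), both (B, L + 1).
    """
    x = np.pad(text, ((0, 0), (1, 0)), mode="constant", constant_values=SOS)
    t = np.pad(text, ((0, 0), (0, 1)), mode="constant", constant_values=IGNORE_ID)
    for i, length in enumerate(text_lengths):
        t[i, length] = EOS  # <eos> goes right after the last real token
    return x, t, text_lengths + 1


text = np.array([[5, 6, 7], [8, 9, 0]])  # second row is padded
x, t, x_lengths = make_lm_pair(text, np.array([3, 2]))
```

Positions at or beyond `x_lengths` (the trailing `0` of the second row) still go through `F.cross_entropy`, but step 3 zeroes them with the `make_pad_mask` fill, so they do not count toward the loss.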
| | | |
| | | def batchify_nll( |
| | | self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100 |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Compute negative log likelihood (nll) from transformer language model |
| | | |
| | | To avoid OOM, this function separates the input into batches. |
| | | It then calls nll for each batch and concatenates the results. |
| | | Args: |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | batch_size: int, number of samples in each batch when computing nll; |
| | | decrease it to avoid OOM, or increase it to use more GPU memory per step |
| | | |
| | | """ |
| | | total_num = text.size(0) |
| | | if total_num <= batch_size: |
| | | nll, x_lengths = self.nll(text, text_lengths) |
| | | else: |
| | | nlls = [] |
| | | x_lengths = [] |
| | | max_length = text_lengths.max() |
| | | |
| | | start_idx = 0 |
| | | while True: |
| | | end_idx = min(start_idx + batch_size, total_num) |
| | | batch_text = text[start_idx:end_idx, :] |
| | | batch_text_lengths = text_lengths[start_idx:end_idx] |
| | | # batch_nll: [B * T] |
| | | batch_nll, batch_x_lengths = self.nll( |
| | | batch_text, batch_text_lengths, max_length=max_length |
| | | ) |
| | | nlls.append(batch_nll) |
| | | x_lengths.append(batch_x_lengths) |
| | | start_idx = end_idx |
| | | if start_idx == total_num: |
| | | break |
| | | nll = torch.cat(nlls) |
| | | x_lengths = torch.cat(x_lengths) |
| | | assert nll.size(0) == total_num |
| | | assert x_lengths.size(0) == total_num |
| | | return nll, x_lengths |
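The chunking loop in `batchify_nll` relies on passing the global `max_length` to every `nll` call: each chunk then has width `max_length + 1`, so the per-chunk results can be joined with `torch.cat` along the batch axis. A NumPy sketch of the index ranges and that shape argument (`batch_ranges` is a hypothetical helper; zeros stand in for the per-chunk nll values):

```python
import numpy as np


def batch_ranges(total_num: int, batch_size: int):
    """Yield [start, end) pairs covering range(total_num), like the while-loop in batchify_nll."""
    start = 0
    while start < total_num:
        end = min(start + batch_size, total_num)
        yield start, end
        start = end


text_lengths = np.array([3, 7, 2, 5, 4])
max_length = int(text_lengths.max())
# Every chunk is padded to the global max_length (+1 for <eos>), so the
# (B_chunk, max_length + 1) matrices can be concatenated along the batch axis
chunks = [np.zeros((end - start, max_length + 1)) for start, end in batch_ranges(len(text_lengths), 2)]
nll = np.concatenate(chunks)
```

Without the shared `max_length`, each chunk would be trimmed to its own longest sample and the widths would no longer match.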
| | | |
| | | def forward( |
| | | self, text: torch.Tensor, text_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | nll, y_lengths = self.nll(text, text_lengths) |
| | | ntokens = y_lengths.sum() |
| | | loss = nll.sum() / ntokens |
| | | stats = dict(loss=loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def collect_feats( |
| | | self, text: torch.Tensor, text_lengths: torch.Tensor |
| | | ) -> Dict[str, torch.Tensor]: |
| | | return {} |
| | |
| | | |
| | | class ContextualParaformerDecoder(ParaformerSANMDecoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| | |
| | | |
| | | x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) |
| | | |
| | | |
| | | return x, tgt_mask, memory, memory_mask, cache |
| | | |
| | | def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None): |
| | |
| | | |
| | | class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01712 |
| | | |
| | |
| | | for i in range(self.att_layer_num): |
| | | decoder = self.decoders[i] |
| | | c = cache[i] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, memory_mask, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | |
| | | j = i + self.att_layer_num |
| | | decoder = self.decoders2[i] |
| | | c = cache[j] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, memory_mask, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | | |
| | | for decoder in self.decoders3: |
| | | x, tgt_mask, memory, memory_mask, _ = decoder( |
| | | x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=None |
| | | ) |
| | | |
| | |
| | | |
| | | class ParaformerSANMDecoder(BaseTransformerDecoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| | |
| | | for i in range(self.att_layer_num): |
| | | decoder = self.decoders[i] |
| | | c = cache[i] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | |
| | | j = i + self.att_layer_num |
| | | decoder = self.decoders2[i] |
| | | c = cache[j] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | | |
| | | for decoder in self.decoders3: |
| | | |
| | | x, tgt_mask, memory, memory_mask, _ = decoder( |
| | | x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=None |
| | | ) |
| | | |
| | |
| | | |
| | | class ParaformerDecoderSAN(BaseTransformerDecoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| | |
| | | import random |
| | | import math |
| | | class MFCCA(AbsESPnetModel): |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | """ |
| | | Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University |
| | | MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario |
| | | https://arxiv.org/abs/2210.05265 |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | |
| | | |
| | | class Paraformer(AbsESPnetModel): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | """ |
| | |
| | | |
| | | class ParaformerBert(Paraformer): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition |
| | | """ |
| | | |
| | |
| | | |
| | | |
| | | class DiarSondModel(AbsESPnetModel): |
| | | """Speaker overlap-aware neural diarization model |
| | | reference: https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | def __init__( |
| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | """ |
| | | |
| | | import logging |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | |
| | | |
| | | class TimestampPredictor(AbsESPnetModel): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | """ |
| | | |
| | | def __init__( |
| | |
| | | |
| | | class UniASR(AbsESPnetModel): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | """ |
| | | |
| | | def __init__( |
| | |
| | | |
| | | |
| | | class VADXOptions: |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__( |
| | | self, |
| | | sample_rate: int = 16000, |
| | |
| | | |
| | | |
| | | class E2EVadSpeechBufWithDoa(object): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self): |
| | | self.start_ms = 0 |
| | | self.end_ms = 0 |
| | |
| | | |
| | | |
| | | class E2EVadFrameProb(object): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self): |
| | | self.noise_prob = 0.0 |
| | | self.speech_prob = 0.0 |
| | |
| | | |
| | | |
| | | class WindowDetector(object): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self, window_size_ms: int, sil_to_speech_time: int, |
| | | speech_to_sil_time: int, frame_size_ms: int): |
| | | self.window_size_ms = window_size_ms |
| | |
| | | |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.ResetDetection() |
| | | self.frontend = frontend |
| | | |
| | | def AllResetDetection(self): |
| | | self.is_final = False |
| | |
| | | segment_batch = [] |
| | | if len(self.output_data_buf) > 0: |
| | | for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| | | if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ |
| | | i].contain_seg_end_point: |
| | | if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ |
| | | i].contain_seg_end_point): |
| | | continue |
| | | segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] |
| | | segment_batch.append(segment) |
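The changed condition only skips a buffered segment that is missing its start or end point while more audio is expected; on the final call (`is_final`), every remaining segment is emitted. A sketch of that filter, assuming a hypothetical `SegBuf` stand-in for the entries of `output_data_buf` (the rest of the loop body is elided in this diff):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SegBuf:  # hypothetical stand-in for an output_data_buf entry
    start_ms: int
    end_ms: int
    contain_seg_start_point: bool
    contain_seg_end_point: bool


def collect_segments(bufs: List[SegBuf], offset: int, is_final: bool) -> List[List[int]]:
    segments = []
    for buf in bufs[offset:]:
        complete = buf.contain_seg_start_point and buf.contain_seg_end_point
        # Before the final chunk, wait for open segments to close;
        # on the final chunk, flush whatever has been buffered
        if not is_final and not complete:
            continue
        segments.append([buf.start_ms, buf.end_ms])
    return segments


bufs = [SegBuf(0, 500, True, True), SegBuf(800, 1200, True, False)]
partial = collect_segments(bufs, 0, is_final=False)
flushed = collect_segments(bufs, 0, is_final=True)
```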
| | |
| | | ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: |
| | | self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | |
| | | self.ComputeScores(feats, in_cache) |
| | | self.ComputeDecibel() |
| | | if not is_final: |
| | | self.DetectCommonFrames() |
| | | else: |
| | |
| | | |
| | | class ConvEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Convolution encoder in OpenNMT framework |
| | | """ |
| | | |
| | |
| | | |
| | | class SelfAttentionEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Self attention encoder in OpenNMT framework |
| | | """ |
| | | |
| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | |
| | | super(ResNet34Diar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | super(ResNet34SpL2RegDiar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | from typeguard import check_argument_types |
| | | import numpy as np |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
| | | class EncoderLayerSANM(nn.Module): |
| | | def __init__( |
| | |
| | | |
| | | class SANMEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | San-m: Memory equipped self-attention for end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01713 |
| | | |
| | |
| | | |
| | | class SANMEncoderChunkOpt(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01712 |
| | | |
| | |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
| | | class SANMVadEncoder(AbsEncoder): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | input_layer: Optional[str] = "conv2d", |
| | | pos_enc_class=SinusoidalPositionEncoder, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | positionwise_layer_type: str = "linear", |
| | | positionwise_conv_kernel_size: int = 1, |
| | | padding_idx: int = -1, |
| | | interctc_layer_idx: List[int] = [], |
| | | interctc_use_conditioning: bool = False, |
| | | kernel_size : int = 11, |
| | | sanm_shfit : int = 0, |
| | | selfattention_layer_type: str = "sanm", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | | if input_layer == "linear": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Linear(input_size, output_size), |
| | | torch.nn.LayerNorm(output_size), |
| | | torch.nn.Dropout(dropout_rate), |
| | | torch.nn.ReLU(), |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer == "conv2d": |
| | | self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d2": |
| | | self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d6": |
| | | self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) |
| | | elif input_layer == "conv2d8": |
| | | self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) |
| | | elif input_layer == "embed": |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), |
| | | SinusoidalPositionEncoder(), |
| | | ) |
| | | elif input_layer is None: |
| | | if input_size == output_size: |
| | | self.embed = None |
| | | else: |
| | | self.embed = torch.nn.Linear(input_size, output_size) |
| | | elif input_layer == "pe": |
| | | self.embed = SinusoidalPositionEncoder() |
| | | else: |
| | | raise ValueError("unknown input_layer: " + input_layer) |
| | | self.normalize_before = normalize_before |
| | | if positionwise_layer_type == "linear": |
| | | positionwise_layer = PositionwiseFeedForward |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d": |
| | | positionwise_layer = MultiLayeredConv1d |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | elif positionwise_layer_type == "conv1d-linear": |
| | | positionwise_layer = Conv1dLinear |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positionwise_conv_kernel_size, |
| | | dropout_rate, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Support only linear, conv1d or conv1d-linear.") |
| | | |
| | | if selfattention_layer_type == "selfattn": |
| | | encoder_selfattn_layer = MultiHeadedAttention |
| | | encoder_selfattn_layer_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | ) |
| | | |
| | | elif selfattention_layer_type == "sanm": |
| | | self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask |
| | | encoder_selfattn_layer_args0 = ( |
| | | attention_heads, |
| | | input_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | encoder_selfattn_layer_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | ) |
| | | |
| | | self.encoders0 = repeat( |
| | | 1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | input_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args0), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | |
| | | self.encoders = repeat( |
| | | num_blocks-1, |
| | | lambda lnum: EncoderLayerSANM( |
| | | output_size, |
| | | output_size, |
| | | self.encoder_selfattn_layer(*encoder_selfattn_layer_args), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | if self.normalize_before: |
| | | self.after_norm = LayerNorm(output_size) |
| | | |
| | | self.interctc_layer_idx = interctc_layer_idx |
| | | if len(interctc_layer_idx) > 0: |
| | | assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks |
| | | self.interctc_use_conditioning = interctc_use_conditioning |
| | | self.conditioning_layer = None |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | vad_indexes: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
| | | ctc: CTC = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
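| | | """Encode input features; the last encoder layer uses a VAD-aware attention mask. |
| | | |
| | | Args: |
| | | xs_pad: input tensor (B, L, D) |
| | | ilens: input length (B) |
| | | vad_indexes: per-utterance index passed to vad_mask to build the last layer's mask (B) |
| | | prev_states: Not to be used now. |
| | | ctc: Not used in this method. |
| | | Returns: |
| | | encoder output, output lengths, and None |
| | | """ |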
| | | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) |
| | | sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0) |
| | | no_future_masks = masks & sub_masks |
| | | xs_pad *= self.output_size()**0.5 |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | | elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2) |
| | | or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)): |
| | | short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {xs_pad.size(1)} frames and is too short for subsampling " + |
| | | f"(it needs more than {limit_size} frames), return empty results", |
| | | xs_pad.size(1), |
| | | limit_size, |
| | | ) |
| | | xs_pad, masks = self.embed(xs_pad, masks) |
| | | else: |
| | | xs_pad = self.embed(xs_pad) |
| | | |
| | | # xs_pad = self.dropout(xs_pad) |
| | | mask_tup0 = [masks, no_future_masks] |
| | | encoder_outs = self.encoders0(xs_pad, mask_tup0) |
| | | xs_pad, _ = encoder_outs[0], encoder_outs[1] |
| | | intermediate_outs = [] |
| | | |
| | | |
| | | for layer_idx, encoder_layer in enumerate(self.encoders): |
| | | if layer_idx + 1 == len(self.encoders): |
| | | # This is last layer. |
| | | coner_mask = torch.ones(masks.size(0), |
| | | masks.size(-1), |
| | | masks.size(-1), |
| | | device=xs_pad.device, |
| | | dtype=torch.bool) |
| | | for word_index, length in enumerate(ilens): |
| | | coner_mask[word_index, :, :] = vad_mask(masks.size(-1), |
| | | vad_indexes[word_index], |
| | | device=xs_pad.device) |
| | | layer_mask = masks & coner_mask |
| | | else: |
| | | layer_mask = no_future_masks |
| | | mask_tup1 = [masks, layer_mask] |
| | | encoder_outs = encoder_layer(xs_pad, mask_tup1) |
| | | xs_pad, layer_mask = encoder_outs[0], encoder_outs[1] |
| | | |
| | | if self.normalize_before: |
| | | xs_pad = self.after_norm(xs_pad) |
| | | |
| | | olens = masks.squeeze(1).sum(1) |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
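The padding and causal masks built at the top of `forward` (`masks`, `sub_masks`, `no_future_masks`) can be reproduced with NumPy. This is a minimal sketch that assumes `make_pad_mask` and `subsequent_mask` follow the usual ESPnet conventions (True marks padding, and a lower-triangular boolean matrix, respectively). `no_future_mask_sketch` is a hypothetical name, and the VAD-specific mask of the last layer is not modelled because `vad_mask` is not shown in this diff.

```python
import numpy as np


def no_future_mask_sketch(ilens) -> np.ndarray:
    """Combine a padding mask and a causal mask into a (B, T, T) boolean mask."""
    ilens = np.asarray(ilens)
    max_len = int(ilens.max())
    # (B, 1, T): True for real frames, False for padding (the role of ~make_pad_mask).
    masks = (np.arange(max_len)[None, :] < ilens[:, None])[:, None, :]
    # (1, T, T): lower-triangular, so frame t only attends to frames <= t
    # (the role of subsequent_mask).
    sub_masks = np.tril(np.ones((max_len, max_len), dtype=bool))[None, :, :]
    # Broadcasting yields (B, T, T), mirroring `masks & sub_masks` in forward().
    return masks & sub_masks
```

For `ilens = [3, 2]`, the second utterance's rows never attend to frame index 2, because that frame is padding.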
| | |
| | | return cmvn |
| | | |
| | | |
| | | def apply_cmvn(inputs, cmvn_file): # noqa |
| | | def apply_cmvn(inputs, cmvn): # noqa |
| | | """ |
| | | Apply CMVN with mvn data |
| | | """ |
| | |
| | | dtype = inputs.dtype |
| | | frame, dim = inputs.shape |
| | | |
| | | cmvn = load_cmvn(cmvn_file) |
| | | means = np.tile(cmvn[0:1, :dim], (frame, 1)) |
| | | vars = np.tile(cmvn[1:2, :dim], (frame, 1)) |
| | | inputs += torch.from_numpy(means).type(dtype).to(device) |
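The hunk above is cut off right after the shift row is added. Assuming `am.mvn` parses to a `(2, D')` array whose first row is an additive shift (for example, negative means) and whose second row is a multiplicative scale (for example, inverse standard deviations), the whole normalisation reduces to one broadcast expression. The sketch below uses NumPy instead of torch, and `apply_cmvn_np` is a hypothetical name. Because the array is reused unchanged for every utterance, loading it once (as the `__init__` change in this diff does with `self.cmvn`) is sufficient.

```python
import numpy as np


def apply_cmvn_np(inputs: np.ndarray, cmvn: np.ndarray) -> np.ndarray:
    """Apply shift-then-scale CMVN to features of shape (frames, dim).

    Assumes cmvn has shape (2, >= dim): row 0 holds additive shifts and
    row 1 holds multiplicative scales.
    """
    frames, dim = inputs.shape
    shift = cmvn[0, :dim].astype(inputs.dtype)
    scale = cmvn[1, :dim].astype(inputs.dtype)
    # Broadcasting over the frame axis replaces the explicit np.tile calls.
    return (inputs + shift) * scale
```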
| | |
| | | self.dither = dither |
| | | self.snip_edges = snip_edges |
| | | self.upsacle_samples = upsacle_samples |
| | | self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file) |
| | | |
| | | def output_size(self) -> int: |
| | | return self.n_mels * self.lfr_m |
| | |
| | | |
| | | if self.lfr_m != 1 or self.lfr_n != 1: |
| | | mat = apply_lfr(mat, self.lfr_m, self.lfr_n) |
| | | if self.cmvn_file is not None: |
| | | mat = apply_cmvn(mat, self.cmvn_file) |
| | | if self.cmvn is not None: |
| | | mat = apply_cmvn(mat, self.cmvn) |
| | | feat_length = mat.size(0) |
| | | feats.append(mat) |
| | | feats_lens.append(feat_length) |
| | |
| | | mat = input[i, :input_lengths[i], :] |
| | | if self.lfr_m != 1 or self.lfr_n != 1: |
| | | mat = apply_lfr(mat, self.lfr_m, self.lfr_n) |
| | | if self.cmvn_file is not None: |
| | | mat = apply_cmvn(mat, self.cmvn_file) |
| | | if self.cmvn is not None: |
| | | mat = apply_cmvn(mat, self.cmvn) |
| | | feat_length = mat.size(0) |
| | | feats.append(mat) |
| | | feats_lens.append(feat_length) |
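`output_size()` returns `n_mels * lfr_m`, which follows from low frame rate (LFR) stacking: at every `lfr_n`-th input frame, `lfr_m` consecutive frames are concatenated into one output frame. The NumPy sketch below illustrates that technique. Padding the left edge with `(lfr_m - 1) // 2` copies of the first frame and the right edge with the last frame is an assumption about how `apply_lfr` treats the boundaries, and `apply_lfr_np` is a hypothetical name.

```python
import numpy as np


def apply_lfr_np(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
    """Stack lfr_m consecutive frames every lfr_n frames (low frame rate)."""
    num_frames, dim = feats.shape
    left = (lfr_m - 1) // 2
    # Repeat the first frame on the left so early windows are roughly centred.
    padded = np.vstack([np.repeat(feats[:1], left, axis=0), feats])
    num_out = int(np.ceil(num_frames / lfr_n))
    out = np.empty((num_out, dim * lfr_m), dtype=feats.dtype)
    for i in range(num_out):
        window = padded[i * lfr_n : i * lfr_n + lfr_m]
        if window.shape[0] < lfr_m:
            # Near the end, repeat the last frame to fill the window.
            fill = np.repeat(padded[-1:], lfr_m - window.shape[0], axis=0)
            window = np.vstack([window, fill])
        # Each output frame has dim * lfr_m values, matching output_size().
        out[i] = window.reshape(-1)
    return out
```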
| | |
| | | last_fire_place = len_time - 1 |
| | | last_fire_remainds = 0.0 |
| | | pre_alphas_length = 0 |
| | | last_fire = False |
| | |
| | | mask_chunk_peak_predictor = None |
| | | if cache is not None: |
| | |
| | | if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold: |
| | | last_fire_place = len_time - 1 - i |
| | | last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold |
| | | last_fire = True |
| | | break |
| | | last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) |
| | | cache["cif_hidden"] = hidden[:, last_fire_place:, :] |
| | | cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) |
| | | if last_fire: |
| | | last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) |
| | | cache["cif_hidden"] = hidden[:, last_fire_place:, :] |
| | | cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) |
| | | else: |
| | | cache["cif_hidden"] = hidden |
| | | cache["cif_alphas"] = alphas |
| | | token_num_int = token_num.floor().type(torch.int32).item() |
| | | return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak |
| | |
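The streaming CIF cache update scans `cif_peak` backwards for the last frame that reached the firing threshold, keeps that frame's leftover weight, and carries the remaining frames into the next chunk. The plain-Python sketch below mirrors that logic for a single utterance (the original indexes batch 0). It assumes `cif_peak` and `alphas` are 1-D lists and defaults the threshold to 1.0; `split_cif_cache` is a hypothetical name. It returns the start index for `cif_hidden` together with the cached alphas.

```python
from typing import List, Tuple


def split_cif_cache(cif_peak: List[float], alphas: List[float],
                    threshold: float = 1.0) -> Tuple[int, List[float], bool]:
    """Return (start index of cached hidden frames, cached alphas, fired?)."""
    len_time = len(cif_peak)
    # Scan backwards for the last frame whose integrated weight reached the threshold.
    for i in range(len_time):
        t = len_time - 1 - i
        if cif_peak[t] >= threshold:
            remainder = cif_peak[t] - threshold
            # Cache hidden frames from the firing frame onward; its leftover
            # weight comes first, followed by the alphas of later frames.
            return t, [remainder] + alphas[t + 1:], True
    # No frame fired in this chunk: carry the whole chunk over.
    return 0, list(alphas), False
```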
| File was renamed from funasr/punctuation/target_delay_transformer.py |
| | |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder |
| | | from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder |
| | | #from funasr.modules.mask import subsequent_n_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class TargetDelayTransformer(AbsPunctuation): |
| | | |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| File was renamed from funasr/punctuation/vad_realtime_transformer.py |
| | |
| | | import torch.nn as nn |
| | | |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | |
| | | |
| | | class VadRealtimeTransformer(AbsPunctuation): |
| | | |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | |
| | | |
| | | class overlap_chunk(): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | San-m: Memory equipped self-attention for end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01713 |
| | | |
| | |
| | | "${_target}.cc") |
| | | target_link_libraries(${_target} |
| | | rg_grpc_proto |
| | | rapidasr |
| | | funasr |
| | | ${EXTRA_LIBS} |
| | | ${_REFLECTION} |
| | | ${_GRPC_GRPCPP} |
| | |
| | | python grpc_main_client_mic.py --host $server_ip --port 10108 |
| | | ``` |
| | | |
| | | The `grpc_main_client_mic.py` follows the [original design](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data in chunks. If you want to send all of the audio_data in a single request, here is an example: |
| | | |
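| | | ```python |
| | | import asyncio |
| | | import json |
| | | import queue |
| | | import time |
| | | |
| | | import grpc |
| | | import soundfile as sf |
| | | |
| | | # go to ../python/grpc to find this package |
| | | import paraformer_pb2 |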
| | | |
| | | |
| | | class RecognizeStub: |
| | | def __init__(self, channel): |
| | | self.Recognize = channel.stream_stream( |
| | | '/paraformer.ASR/Recognize', |
| | | request_serializer=paraformer_pb2.Request.SerializeToString, |
| | | response_deserializer=paraformer_pb2.Response.FromString, |
| | | ) |
| | | |
| | | |
| | | async def send(channel, data, speaking, isEnd): |
| | | stub = RecognizeStub(channel) |
| | | req = paraformer_pb2.Request() |
| | | if data: |
| | | req.audio_data = data |
| | | req.user = 'zz' |
| | | req.language = 'zh-CN' |
| | | req.speaking = speaking |
| | | req.isEnd = isEnd |
| | | q = queue.SimpleQueue() |
| | | q.put(req) |
| | | return stub.Recognize(iter(q.get, None)) |
| | | |
| | | # send the audio data once |
| | | async def grpc_rec(data, grpc_uri): |
| | | with grpc.insecure_channel(grpc_uri) as channel: |
| | | b = time.time() |
| | | response = await send(channel, data, False, False) |
| | | resp = response.next() |
| | | text = '' |
| | | if 'decoding' == resp.action: |
| | | resp = response.next() |
| | | if 'finish' == resp.action: |
| | | text = json.loads(resp.sentence)['text'] |
| | | response = await send(channel, None, False, True) |
| | | return { |
| | | 'text': text, |
| | | 'time': time.time() - b, |
| | | } |
| | | |
| | | async def test(): |
| | | # fc = FunAsrGrpcClient('127.0.0.1', 9900) |
| | | # t = await fc.rec(wav.tobytes()) |
| | | # print(t) |
| | | wav, _ = sf.read('z-10s.wav', dtype='int16') |
| | | uri = '127.0.0.1:9900' |
| | | res = await grpc_rec(wav.tobytes(), uri) |
| | | print(res) |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | asyncio.run(test()) |
| | | |
| | | ``` |
| | | |
| | | |
| | | ## Acknowledge |
| | | 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR). |
| | | 2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service. |
| | |
| | | #include "paraformer.grpc.pb.h" |
| | | #include "paraformer_server.h" |
| | | |
| | | |
| | | using grpc::Server; |
| | | using grpc::ServerBuilder; |
| | | using grpc::ServerContext; |
| | |
| | | using grpc::ServerWriter; |
| | | using grpc::Status; |
| | | |
| | | |
| | | using paraformer::Request; |
| | | using paraformer::Response; |
| | | using paraformer::ASR; |
| | | |
| | | ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) { |
| | | AsrHanlde=RapidAsrInit(model_path, thread_num, quantize); |
| | | AsrHanlde=FunASRInit(model_path, thread_num, quantize); |
| | | std::cout << "ASRServicer init" << std::endl; |
| | | init_flag = 0; |
| | | } |
| | | |
| | | void ASRServicer::clear_states(const std::string& user) { |
| | | clear_buffers(user); |
| | | clear_transcriptions(user); |
| | | } |
| | | |
| | | void ASRServicer::clear_buffers(const std::string& user) { |
| | | if (client_buffers.count(user)) { |
| | | client_buffers.erase(user); |
| | | } |
| | | } |
| | | |
| | | void ASRServicer::clear_transcriptions(const std::string& user) { |
| | | if (client_transcription.count(user)) { |
| | | client_transcription.erase(user); |
| | | } |
| | | } |
| | | |
| | | void ASRServicer::disconnect(const std::string& user) { |
| | | clear_states(user); |
| | | std::cout << "Disconnecting user: " << user << std::endl; |
| | | } |
| | | |
| | | grpc::Status ASRServicer::Recognize( |
| | |
| | | grpc::ServerReaderWriter<Response, Request>* stream) { |
| | | |
| | | Request req; |
| | | std::unordered_map<std::string, std::string> client_buffers; |
| | | std::unordered_map<std::string, std::string> client_transcription; |
| | | |
| | | while (stream->Read(&req)) { |
| | | if (req.isend()) { |
| | | std::cout << "asr end" << std::endl; |
| | | disconnect(req.user()); |
| | | // disconnect |
| | | if (client_buffers.count(req.user())) { |
| | | client_buffers.erase(req.user()); |
| | | } |
| | | if (client_transcription.count(req.user())) { |
| | | client_transcription.erase(req.user()); |
| | | } |
| | | |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "asr end"})" |
| | |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | } else if (!req.speaking()) { |
| | | if (client_buffers.count(req.user()) == 0) { |
| | | if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) { |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "waiting_for_voice"})" |
| | |
| | | stream->Write(res); |
| | | }else { |
| | | auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); |
| | | std::string tmp_data = this->client_buffers[req.user()]; |
| | | this->clear_states(req.user()); |
| | | |
| | | if (req.audio_data().size() > 0) { |
| | | auto& buf = client_buffers[req.user()]; |
| | | buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end()); |
| | | } |
| | | std::string tmp_data = client_buffers[req.user()]; |
| | | // clear_states |
| | | if (client_buffers.count(req.user())) { |
| | | client_buffers.erase(req.user()); |
| | | } |
| | | if (client_transcription.count(req.user())) { |
| | | client_transcription.erase(req.user()); |
| | | } |
| | | |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})" |
| | | ); |
| | | int data_len_int = tmp_data.length(); |
| | | int data_len_int = tmp_data.length(); |
| | | std::string data_len = std::to_string(data_len_int); |
| | | std::stringstream ss; |
| | | ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})"; |
| | |
| | | res.set_user(req.user()); |
| | | res.set_action("finish"); |
| | | res.set_language(req.language()); |
| | | |
| | | |
| | | |
| | | stream->Write(res); |
| | | } |
| | | else { |
| | | RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL); |
| | | std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg; |
| | | FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL); |
| | | std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; |
| | | |
| | | auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); |
| | | std::string delay_str = std::to_string(end_time - begin_time); |
| | | |
| | | |
| | | std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl; |
| | | Response res; |
| | | std::stringstream ss; |
| | |
| | | res.set_user(req.user()); |
| | | res.set_action("finish"); |
| | | res.set_language(req.language()); |
| | | |
| | | |
| | | |
| | | stream->Write(res); |
| | | } |
| | | } |
| | |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | } |
| | | } |
| | | } |
| | | return Status::OK; |
| | | } |
| | | |
| | | |
| | | void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) { |
| | | std::string server_address; |
| | |
| | | #include <chrono> |
| | | |
| | | #include "paraformer.grpc.pb.h" |
| | | #include "librapidasrapi.h" |
| | | #include "libfunasrapi.h" |
| | | |
| | | |
| | | using grpc::Server; |
| | |
| | | { |
| | | std::string msg; |
| | | float snippet_time; |
| | | }RPASR_RECOG_RESULT; |
| | | }FUNASR_RECOG_RESULT; |
| | | |
| | | |
| | | class ASRServicer final : public ASR::Service { |
| | | private: |
| | | int init_flag; |
| | | std::unordered_map<std::string, std::string> client_buffers; |
| | | std::unordered_map<std::string, std::string> client_transcription; |
| | | |
| | | public: |
| | | ASRServicer(const char* model_path, int thread_num, bool quantize); |
| | | void clear_states(const std::string& user); |
| | | void clear_buffers(const std::string& user); |
| | | void clear_transcriptions(const std::string& user); |
| | | void disconnect(const std::string& user); |
| | | grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream); |
| | | RPASR_HANDLE AsrHanlde; |
| | | FUNASR_HANDLE AsrHanlde; |
| | | |
| | | }; |
| | |
| | | |
| | | project(FunASRonnx) |
| | | |
| | | set(CMAKE_CXX_STANDARD 11) |
| | | # set(CMAKE_CXX_STANDARD 11) |
| | | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") |
| | | set(CMAKE_POSITION_INDEPENDENT_CODE ON) |
| | | |
| | | include(TestBigEndian) |
| | | test_big_endian(BIG_ENDIAN) |
| | | if(BIG_ENDIAN) |
| | | message("Big endian system") |
| | | else() |
| | | message("Little endian system") |
| | | endif() |
| | | |
| | | # for onnxruntime |
| | | |
| | | IF(WIN32) |
| | | |
| | | |
| | | if(CMAKE_CL_64) |
| | | link_directories(${ONNXRUNTIME_DIR}\\lib) |
| | | else() |
| | | add_definitions(-D_WIN_X86) |
| | | endif() |
| | | ELSE() |
| | | |
| | | |
| | | link_directories(${ONNXRUNTIME_DIR}/lib) |
| | | |
| | | link_directories(${ONNXRUNTIME_DIR}/lib) |
| | | endif() |
| | | |
| | | add_subdirectory("./third_party/yaml-cpp") |
| | |
| | | #include <queue> |
| | | #include <stdint.h> |
| | | |
| | | #ifndef model_sample_rate |
| | | #define model_sample_rate 16000 |
| | | #endif |
| | | #ifndef WAV_HEADER_SIZE |
| | | #define WAV_HEADER_SIZE 44 |
| | | #endif |
| | | |
| | | using namespace std; |
| | | |
| | | class AudioFrame { |
| | |
| | | int16_t *speech_buff; |
| | | int speech_len; |
| | | int speech_align_len; |
| | | int16_t sample_rate; |
| | | int offset; |
| | | float align_size; |
| | | int data_type; |
| | |
| | | Audio(int data_type, int size); |
| | | ~Audio(); |
| | | void disp(); |
| | | bool loadwav(const char* filename); |
| | | bool loadwav(const char* buf, int nLen); |
| | | bool loadpcmwav(const char* buf, int nFileLen); |
| | | bool loadpcmwav(const char* filename); |
| | | bool loadwav(const char* filename, int32_t* sampling_rate); |
| | | void wavResample(int32_t sampling_rate, const float *waveform, int32_t n); |
| | | bool loadwav(const char* buf, int nLen, int32_t* sampling_rate); |
| | | bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate); |
| | | bool loadpcmwav(const char* filename, int32_t* sampling_rate); |
| | | int fetch_chunck(float *&dout, int len); |
| | | int fetch(float *&dout, int &len, int &flag); |
| | | void padding(); |
| New file |
| | |
| | | #pragma once |
| | | |
| | | #ifdef WIN32 |
| | | #ifdef _FUNASR_API_EXPORT |
| | | #define _FUNASRAPI __declspec(dllexport) |
| | | #else |
| | | #define _FUNASRAPI __declspec(dllimport) |
| | | #endif |
| | | #else |
| | | #define _FUNASRAPI |
| | | #endif |
| | | |
| | | #ifndef _WIN32 |
| | | #define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__)) |
| | | #else |
| | | #define FUNASR_CALLBCK_PREFIX __stdcall |
| | | #endif |
| | | |
| | | #ifdef __cplusplus |
| | | |
| | | extern "C" { |
| | | #endif |
| | | |
| | | typedef void* FUNASR_HANDLE; |
| | | typedef void* FUNASR_RESULT; |
| | | typedef unsigned char FUNASR_BOOL; |
| | | |
| | | #define FUNASR_TRUE 1 |
| | | #define FUNASR_FALSE 0 |
| | | #define QM_DEFAULT_THREAD_NUM 4 |
| | | |
| | | typedef enum |
| | | { |
| | | RASR_NONE=-1, |
| | | RASRM_CTC_GREEDY_SEARCH=0, |
| | | RASRM_CTC_RPEFIX_BEAM_SEARCH = 1, |
| | | RASRM_ATTENSION_RESCORING = 2, |
| | | |
| | | }FUNASR_MODE; |
| | | |
| | | typedef enum { |
| | | FUNASR_MODEL_PADDLE = 0, |
| | | FUNASR_MODEL_PADDLE_2 = 1, |
| | | FUNASR_MODEL_K2 = 2, |
| | | FUNASR_MODEL_PARAFORMER = 3, |
| | | |
| | | }FUNASR_MODEL_TYPE; |
| | | |
| | | typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step. |
| | | |
| | | // APIs for qmasr |
| | | _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize); |
| | | |
| | | |
| | | // If no fnCallback is given, pass NULL. |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback); |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback); |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback); |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback); |
| | | |
| | | _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex); |
| | | |
| | | _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result); |
| | | |
| | | _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result); |
| | | |
| | | _FUNASRAPI void FunASRUninit(FUNASR_HANDLE Handle); |
| | | |
| | | _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result); |
| | | |
| | | #ifdef __cplusplus |
| | | |
| | | } |
| | | #endif |
| | |
| | | |
| | | ## Quick Start |
| | | |
| | | ### Windows |
| | | |
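| | | Install VS2022, open the CMake project in the cpp_onnx directory, and build it directly. This repository already provides all the required dependency libraries. |
| | | |
| | | The fftw3 and onnxruntime libraries are already bundled for Windows. |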
| | | |
| | | ### Linux |
| | | See the bottom of this page: Building Guidance |
| | | |
| | | ### Run the program |
| | | |
| | | tester /path/to/models_dir /path/to/wave_file quantize(true or false) |
| | | |
| | | For example: tester /data/models /data/test.wav false |
| | | |
| | | /data/models must contain the following three files: config.yaml, am.mvn, model.onnx (or model_quant.onnx) |
| | | |
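Before launching `tester`, it can help to confirm that the model directory holds the three files listed above. This is a minimal Python sketch; it assumes that passing `quantize=true` makes the runtime load `model_quant.onnx` and otherwise `model.onnx`, and `missing_model_files` is a hypothetical helper, not part of FunASR.

```python
from pathlib import Path
from typing import List


def missing_model_files(model_dir: str, quantize: bool) -> List[str]:
    """Return the files tester is expected to need that are absent from model_dir."""
    # Assumption: quantize=true selects model_quant.onnx, otherwise model.onnx.
    onnx_name = "model_quant.onnx" if quantize else "model.onnx"
    required = ["config.yaml", "am.mvn", onnx_name]
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]
```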
| | | ## Supported Platforms |
| | | - Windows |
| | | - Linux/Unix |
| | | |
| | | ## Dependencies |
| | | - fftw3 |
| | | - openblas |
| | | - onnxruntime |
| | | |
| | | ## Export the model to ONNX format |
| | | Install modelscope and FunASR (dependencies: torch, torchaudio); see the [detailed installation guide](https://github.com/alibaba-damo-academy/FunASR/wiki). |
| | | ## Demo |
| | | ```shell |
| | | pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html |
| | | git clone https://github.com/alibaba/FunASR.git && cd FunASR |
| | | pip install --editable ./ |
| | | tester /path/models_dir /path/wave_file quantize(true or false) |
| | | ``` |
| | | To export an ONNX model, see the [export guide](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export). For example, to export a model from modelscope: |
| | | |
| | | The structure of /path/models_dir |
| | | ``` |
| | | config.yaml, am.mvn, model.onnx(or model_quant.onnx) |
| | | ``` |
| | | |
| | | ## Steps |
| | | |
| | | ### Export onnx |
| | | #### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation) |
| | | |
| | | ```shell |
| | | pip3 install torch torchaudio |
| | | pip install -U modelscope |
| | | pip install -U funasr |
| | | ``` |
| | | #### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) |
| | | |
| | | ```shell |
| | | python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True |
| | | ``` |
| | | |
| | | ## Building Guidance for Linux/Unix |
| | | ### Building for Linux/Unix |
| | | |
| | | ``` |
| | | git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/onnxruntime |
| | | mkdir build |
| | | cd build |
| | | #### Download onnxruntime |
| | | ```shell |
| | | # download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0 |
| | | # here we get a copy of onnxruntime for linux 64 |
| | | wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz |
| | | tar -zxvf onnxruntime-linux-x64-1.14.0.tgz |
| | | # ls |
| | | # onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz |
| | | ``` |
| | | |
| | | #install fftw3-dev |
| | | ubuntu: apt install libfftw3-dev |
| | | centos: yum install fftw fftw-devel |
| | | #### Install fftw3 |
| | | ```shell |
| | | sudo apt install libfftw3-dev #ubuntu |
| | | # sudo yum install fftw fftw-devel #centos |
| | | ``` |
| | | |
| | | #install openblas |
| | | bash ./third_party/install_openblas.sh |
| | | #### Install openblas |
| | | ```shell |
| | | sudo apt-get install libopenblas-dev #ubuntu |
| | | # sudo yum -y install openblas-devel #centos |
| | | ``` |
| | | |
| | | # build |
| | | cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 |
| | | make |
| | | #### Build runtime |
| | | ```shell |
| | | git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/onnxruntime |
| | | mkdir build && cd build |
| | | cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 |
| | | make |
| | | ``` |
| | | |
| | | # then in the tester subfolder of the current directory, you will find the tester program |
| | | |
| | | ```` |
| | | |
| | | ### The structure of a qualified onnxruntime package. |
| | | #### The structure of a qualified onnxruntime package. |
| | | ``` |
| | | onnxruntime_xxx |
| | | ├───include |
| | | └───lib |
| | | ``` |
| | | |
| | | ## Note |
| | | This program only supports **mono** audio with a 16000 Hz sample rate and 16-bit depth. |
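A quick pre-flight check of an input file against the constraints in the note above can be written with Python's standard `wave` module. This is a sketch, and `check_wav_format` is a hypothetical helper. The sample-rate check follows the note literally; the new `Audio::loadwav` in this diff resamples other rates via `wavResample`, so that check may be stricter than the updated runtime requires.

```python
import wave


def check_wav_format(path: str) -> None:
    """Raise ValueError unless path is 16 kHz, 16-bit, mono PCM WAV."""
    # wave.open only reads uncompressed PCM; other encodings raise wave.Error.
    with wave.open(path, "rb") as w:
        if w.getnchannels() != 1:
            raise ValueError(f"expected mono, got {w.getnchannels()} channels")
        if w.getsampwidth() != 2:
            raise ValueError(f"expected 16-bit samples, got {8 * w.getsampwidth()}-bit")
        if w.getframerate() != 16000:
            raise ValueError(f"expected 16000 Hz, got {w.getframerate()} Hz")
```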
| | | ### Building for Windows |
| | | |
| | | Refer to win/ |
| | | |
| | | ## Acknowledge |
| | | 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR). |
| | |
| | | #include <stdio.h> |
| | | #include <stdlib.h> |
| | | #include <string.h> |
| | | #include <fstream> |
| | | #include <assert.h> |
| | | |
| | | #include "Audio.h" |
| | | #include "precomp.h" |
| | | |
| | | using namespace std; |
| | | |
| | | // see http://soundfile.sapp.org/doc/WaveFormat/ |
| | | // Note: We assume little endian here |
| | | struct WaveHeader { |
| | | bool Validate() const { |
| | | // F F I R |
| | | if (chunk_id != 0x46464952) { |
| | | printf("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id); |
| | | return false; |
| | | } |
| | | // E V A W |
| | | if (format != 0x45564157) { |
| | | printf("Expected format WAVE. Given: 0x%08x\n", format); |
| | | return false; |
| | | } |
| | | |
| | | if (subchunk1_id != 0x20746d66) { |
| | | printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n", |
| | | subchunk1_id); |
| | | return false; |
| | | } |
| | | |
| | | if (subchunk1_size != 16) { // 16 for PCM |
| | | printf("Expected subchunk1_size 16. Given: %d\n", |
| | | subchunk1_size); |
| | | return false; |
| | | } |
| | | |
| | | if (audio_format != 1) { // 1 for PCM |
| | | printf("Expected audio_format 1. Given: %d\n", audio_format); |
| | | return false; |
| | | } |
| | | |
| | | if (num_channels != 1) { // we support only single channel for now |
| | | printf("Expected single channel. Given: %d\n", num_channels); |
| | | return false; |
| | | } |
| | | if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { |
| | | return false; |
| | | } |
| | | |
| | | if (block_align != (num_channels * bits_per_sample / 8)) { |
| | | return false; |
| | | } |
| | | |
| | | if (bits_per_sample != 16) { // we support only 16 bits per sample |
| | | printf("Expected bits_per_sample 16. Given: %d\n", |
| | | bits_per_sample); |
| | | return false; |
| | | } |
| | | return true; |
| | | } |
| | | |
| | | // See https://en.wikipedia.org/wiki/WAV#Metadata and |
| | | // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf |
| | | void SeekToDataChunk(std::istream &is) { |
| | | // a t a d |
| | | while (is && subchunk2_id != 0x61746164) { |
| | | // const char *p = reinterpret_cast<const char *>(&subchunk2_id); |
| | | // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], |
| | | // p[1], p[2], p[3], subchunk2_size); |
| | | is.seekg(subchunk2_size, std::istream::cur); |
| | | is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t)); |
| | | is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t)); |
| | | } |
| | | } |
| | | |
| | | int32_t chunk_id; |
| | | int32_t chunk_size; |
| | | int32_t format; |
| | | int32_t subchunk1_id; |
| | | int32_t subchunk1_size; |
| | | int16_t audio_format; |
| | | int16_t num_channels; |
| | | int32_t sample_rate; |
| | | int32_t byte_rate; |
| | | int16_t block_align; |
| | | int16_t bits_per_sample; |
| | | int32_t subchunk2_id; // a tag of this chunk |
| | | int32_t subchunk2_size; // size of subchunk2 |
| | | }; |
| | | static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, ""); |
| | | |
| | | class AudioWindow { |
| | | private: |
| | |
| | | float frame_length = 400; |
| | | float frame_shift = 160; |
| | | float num_new_samples = |
| | | ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length; |
| | | ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length; |
| | | |
| | | end = start + num_new_samples; |
| | | len = (int)num_new_samples; |
| | |
| | | |
| | | void Audio::disp() |
| | | { |
| | | printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000, |
| | | printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate, |
| | | speech_len); |
| | | } |
| | | |
| | | float Audio::get_time_len() |
| | | { |
| | | return (float)speech_len / 16000; |
| | | //speech_len); |
| | | return (float)speech_len / model_sample_rate; |
| | | } |
| | | |
| | | bool Audio::loadwav(const char *filename) |
| | | void Audio::wavResample(int32_t sampling_rate, const float *waveform, |
| | | int32_t n) |
| | | { |
| | | printf( |
| | | "Creating a resampler:\n" |
| | | " in_sample_rate: %d\n" |
| | | " output_sample_rate: %d\n", |
| | | sampling_rate, static_cast<int32_t>(model_sample_rate)); |
| | | float min_freq = |
| | | std::min<int32_t>(sampling_rate, model_sample_rate); |
| | | float lowpass_cutoff = 0.99 * 0.5 * min_freq; |
| | | |
| | | int32_t lowpass_filter_width = 6; |
| | | //FIXME |
| | | //auto resampler = new LinearResample( |
| | | // sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width); |
| | | auto resampler = std::make_unique<LinearResample>( |
| | | sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width); |
| | | std::vector<float> samples; |
| | | resampler->Resample(waveform, n, true, &samples); |
| | | //reset speech_data |
| | | speech_len = samples.size(); |
| | | if (speech_data != NULL) { |
| | | free(speech_data); |
| | | } |
| | | speech_data = (float*)malloc(sizeof(float) * speech_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_len); |
| | | copy(samples.begin(), samples.end(), speech_data); |
| | | } |
| | | |
| | | bool Audio::loadwav(const char *filename, int32_t* sampling_rate) |
| | | { |
| | | WaveHeader header; |
| | | if (speech_data != NULL) { |
| | | free(speech_data); |
| | | } |
| | | if (speech_buff != NULL) { |
| | | free(speech_buff); |
| | | } |
| | | |
| | | |
| | | offset = 0; |
| | | |
| | | FILE *fp; |
| | | fp = fopen(filename, "rb"); |
| | | if (fp == nullptr) |
| | | std::ifstream is(filename, std::ifstream::binary); |
| | | is.read(reinterpret_cast<char *>(&header), sizeof(header)); |
| | | if(!is){ |
| | | fprintf(stderr, "Failed to read %s\n", filename); |
| | | return false; |
| | | fseek(fp, 0, SEEK_END); /* seek to the end of the file */ |
| | | uint32_t nFileLen = ftell(fp); /* get the file size */ |
| | | fseek(fp, 44, SEEK_SET); /* skip the WAV header */ |
| | | |
| | | speech_len = (nFileLen - 44) / 2; |
| | | speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size); |
| | | speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len); |
| | | } |
| | | |
| | | *sampling_rate = header.sample_rate; |
| | | // header.subchunk2_size contains the number of bytes in the data. |
| | | // As we assume each sample contains two bytes, so it is divided by 2 here |
| | | speech_len = header.subchunk2_size / 2; |
| | | speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len); |
| | | |
| | | if (speech_buff) |
| | | { |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_align_len); |
| | | int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp); |
| | | fclose(fp); |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_len); |
| | | is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size); |
| | | if (!is) { |
| | | fprintf(stderr, "Failed to read %s\n", filename); |
| | | return false; |
| | | } |
| | | speech_data = (float*)malloc(sizeof(float) * speech_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_len); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_align_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_align_len); |
| | | int i; |
| | | float scale = 1; |
| | | |
| | | if (data_type == 1) { |
| | | scale = 32768; |
| | | } |
| | | |
| | | for (i = 0; i < speech_len; i++) { |
| | | for (int32_t i = 0; i != speech_len; ++i) { |
| | | speech_data[i] = (float)speech_buff[i] / scale; |
| | | } |
| | | |
| | | //resample |
| | | if(*sampling_rate != model_sample_rate){ |
| | | wavResample(*sampling_rate, speech_data, speech_len); |
| | | } |
| | | |
| | | AudioFrame* frame = new AudioFrame(speech_len); |
| | | frame_queue.push(frame); |
| | | |
| | | |
| | | return true; |
| | | } |
| | |
| | | return false; |
| | | } |
| | | |
| | | |
| | | bool Audio::loadwav(const char* buf, int nFileLen) |
| | | bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate) |
| | | { |
| | | |
| | | |
| | | |
| | | WaveHeader header; |
| | | if (speech_data != NULL) { |
| | | free(speech_data); |
| | | } |
| | | if (speech_buff != NULL) { |
| | | free(speech_buff); |
| | | } |
| | | |
| | | offset = 0; |
| | | |
| | | size_t nOffset = 0; |
| | | std::memcpy(&header, buf, sizeof(header)); |
| | | |
| | | #define WAV_HEADER_SIZE 44 |
| | | |
| | | speech_len = (nFileLen - WAV_HEADER_SIZE) / 2; |
| | | speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size); |
| | | speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len); |
| | | *sampling_rate = header.sample_rate; |
| | | speech_len = header.subchunk2_size / 2; |
| | | speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len); |
| | | if (speech_buff) |
| | | { |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_align_len); |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_len); |
| | | memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t)); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_len); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_align_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_align_len); |
| | | int i; |
| | | float scale = 1; |
| | | |
| | | if (data_type == 1) { |
| | | scale = 32768; |
| | | } |
| | | |
| | | for (i = 0; i < speech_len; i++) { |
| | | for (int32_t i = 0; i != speech_len; ++i) { |
| | | speech_data[i] = (float)speech_buff[i] / scale; |
| | | } |
| | | |
| | | //resample |
| | | if(*sampling_rate != model_sample_rate){ |
| | | wavResample(*sampling_rate, speech_data, speech_len); |
| | | } |
| | | |
| | | AudioFrame* frame = new AudioFrame(speech_len); |
| | | frame_queue.push(frame); |
| | | |
| | | return true; |
| | | } |
| | | else |
| | | return false; |
| | | |
| | | } |
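In the lines shown, the buffer-based `loadwav` takes its sample count from `header.subchunk2_size` and copies from `buf + WAV_HEADER_SIZE` without comparing that size against `nFileLen`, so a truncated buffer could be over-read. A minimal bounds-checked sketch, assuming the canonical 44-byte header layout (sample rate at byte offset 24, data size at offset 40) and 16-bit mono PCM; `ParseWavBuffer` and `ReadLE` are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Reads a little-endian unsigned integer of n bytes starting at p.
static uint32_t ReadLE(const unsigned char* p, int n) {
  uint32_t v = 0;
  for (int i = n - 1; i >= 0; --i) v = (v << 8) | p[i];
  return v;
}

// Hypothetical bounds-checked variant of Audio::loadwav(buf, nFileLen, ...).
// Assumes the "data" chunk directly follows a 16-byte "fmt " chunk and that
// samples are 16-bit mono. Returns false on a short or malformed buffer.
bool ParseWavBuffer(const char* buf, int len, int32_t* sample_rate,
                    std::vector<int16_t>* samples) {
  const int kHeaderSize = 44;
  if (buf == nullptr || len < kHeaderSize) return false;
  const unsigned char* p = reinterpret_cast<const unsigned char*>(buf);
  if (std::memcmp(p, "RIFF", 4) != 0 || std::memcmp(p + 8, "WAVE", 4) != 0 ||
      std::memcmp(p + 36, "data", 4) != 0)
    return false;
  *sample_rate = static_cast<int32_t>(ReadLE(p + 24, 4));
  uint32_t data_bytes = ReadLE(p + 40, 4);             // subchunk2_size
  uint32_t available = static_cast<uint32_t>(len - kHeaderSize);
  if (data_bytes > available) data_bytes = available;  // never read past the end
  size_t n = data_bytes / 2;                           // 2 bytes per sample
  samples->resize(n);
  for (size_t i = 0; i < n; ++i) {
    (*samples)[i] = static_cast<int16_t>(ReadLE(p + kHeaderSize + 2 * i, 2));
  }
  return true;
}
```

Files whose header carries extra chunks (for example a `LIST` chunk before `data`) fail the `data` tag check here instead of being misread; a fuller parser would walk the chunk list.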
| | | |
| | | |
| | | bool Audio::loadpcmwav(const char* buf, int nBufLen) |
| | | bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate) |
| | | { |
| | | if (speech_data != NULL) { |
| | | free(speech_data); |
| | |
| | | } |
| | | offset = 0; |
| | | |
| | | size_t nOffset = 0; |
| | | |
| | | |
| | | |
| | | speech_len = nBufLen / 2; |
| | | speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size); |
| | | speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len); |
| | | speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len); |
| | | if (speech_buff) |
| | | { |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_align_len); |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_len); |
| | | memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t)); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_len); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_align_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_align_len); |
| | | |
| | | |
| | | int i; |
| | | float scale = 1; |
| | | |
| | | if (data_type == 1) { |
| | | scale = 32768; |
| | | } |
| | | |
| | | for (i = 0; i < speech_len; i++) { |
| | | for (int32_t i = 0; i != speech_len; ++i) { |
| | | speech_data[i] = (float)speech_buff[i] / scale; |
| | | } |
| | | |
| | | //resample |
| | | if(*sampling_rate != model_sample_rate){ |
| | | wavResample(*sampling_rate, speech_data, speech_len); |
| | | } |
| | | |
| | | AudioFrame* frame = new AudioFrame(speech_len); |
| | |
| | | } |
| | | else |
| | | return false; |
| | | |
| | | |
| | | } |
| | | |
| | | bool Audio::loadpcmwav(const char* filename) |
| | | bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate) |
| | | { |
| | | |
| | | if (speech_data != NULL) { |
| | | free(speech_data); |
| | | } |
| | |
| | | fseek(fp, 0, SEEK_SET); |
| | | |
| | | speech_len = (nFileLen) / 2; |
| | | speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size); |
| | | speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len); |
| | | speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len); |
| | | if (speech_buff) |
| | | { |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_align_len); |
| | | memset(speech_buff, 0, sizeof(int16_t) * speech_len); |
| | | int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp); |
| | | fclose(fp); |
| | | |
| | | speech_data = (float*)malloc(sizeof(float) * speech_align_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_align_len); |
| | | speech_data = (float*)malloc(sizeof(float) * speech_len); |
| | | memset(speech_data, 0, sizeof(float) * speech_len); |
| | | |
| | | |
| | | |
| | | int i; |
| | | float scale = 1; |
| | | |
| | | if (data_type == 1) { |
| | | scale = 32768; |
| | | } |
| | | |
| | | for (i = 0; i < speech_len; i++) { |
| | | for (int32_t i = 0; i != speech_len; ++i) { |
| | | speech_data[i] = (float)speech_buff[i] / scale; |
| | | } |
| | | |
| | | //resample |
| | | if(*sampling_rate != model_sample_rate){ |
| | | wavResample(*sampling_rate, speech_data, speech_len); |
| | | } |
| | | |
| | | AudioFrame* frame = new AudioFrame(speech_len); |
| | | frame_queue.push(frame); |
| | | |
| | | |
| | | return true; |
| | | } |
| | |
| | | return false; |
| | | |
| | | } |
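`loadpcmwav(filename, ...)` sizes its buffer from the file length and halves it to get a sample count, but in the lines shown the value returned by `fread` is stored in `ret` without being checked. A sketch that keeps only the samples actually read, assuming headerless 16-bit PCM stored in the host's byte order; `ReadRawPcm` is a hypothetical helper:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper: reads headerless 16-bit PCM from an open file.
// The sample count is derived from the file length (bytes / 2, as in
// Audio::loadpcmwav), and the result is trimmed to what fread returned.
// Samples are interpreted in the host's byte order.
bool ReadRawPcm(FILE* fp, std::vector<int16_t>* samples) {
  if (fp == nullptr || std::fseek(fp, 0, SEEK_END) != 0) return false;
  long n_bytes = std::ftell(fp);
  if (n_bytes < 0 || std::fseek(fp, 0, SEEK_SET) != 0) return false;
  size_t want = static_cast<size_t>(n_bytes) / 2;  // a trailing odd byte is ignored
  samples->assign(want, 0);
  size_t got = std::fread(samples->data(), sizeof(int16_t), want, fp);
  samples->resize(got);  // keep only the samples that were actually read
  return got == want;
}
```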
| | | |
| | | |
| | | int Audio::fetch_chunck(float *&dout, int len) |
| | | { |
| | |
| | | |
| | | file(GLOB files1 "*.cpp") |
| | | file(GLOB files2 "*.cc") |
| | | file(GLOB files4 "paraformer/*.cpp") |
| | | |
| | | set(files ${files1} ${files2} ${files3} ${files4}) |
| | | |
| | | # message("${files}") |
| | | |
| | | add_library(rapidasr ${files}) |
| | | add_library(funasr ${files}) |
| | | |
| | | if(WIN32) |
| | | |
| | | set(EXTRA_LIBS libfftw3f-3 yaml-cpp) |
| | | if(CMAKE_CL_64) |
| | | target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64) |
| | | target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64) |
| | | else() |
| | | target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86) |
| | | target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86) |
| | | endif() |
| | | target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include ) |
| | | target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include ) |
| | | |
| | | target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT) |
| | | target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT) |
| | | else() |
| | | |
| | | set(EXTRA_LIBS fftw3f pthread yaml-cpp) |
| | | target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include") |
| | | target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib") |
| | | target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include") |
| | | target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib") |
| | | |
| | | target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include") |
| | | target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib") |
| | | target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include") |
| | | target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib") |
| | | |
| | | target_include_directories(rapidasr PUBLIC "/usr/include") |
| | | target_link_directories(rapidasr PUBLIC "/usr/lib64") |
| | | target_include_directories(funasr PUBLIC "/usr/include") |
| | | target_link_directories(funasr PUBLIC "/usr/lib64") |
| | | |
| | | target_include_directories(rapidasr PUBLIC ${FFTW3F_INCLUDE_DIR}) |
| | | target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR}) |
| | | target_include_directories(funasr PUBLIC ${FFTW3F_INCLUDE_DIR}) |
| | | target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR}) |
| | | include_directories(${ONNXRUNTIME_DIR}/include) |
| | | endif() |
| | | |
| | | include_directories(${CMAKE_SOURCE_DIR}/include) |
| | | target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS}) |
| | | target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS}) |
| | | |
| | | |
| | | |
| | |
| | | |
| | | FeatureExtract::FeatureExtract(int mode) : mode(mode) |
| | | { |
| | | fftw_init(); |
| | | } |
| | | |
| | | FeatureExtract::~FeatureExtract() |
| | | { |
| | | fftwf_free(fft_input); |
| | | fftwf_free(fft_out); |
| | | fftwf_destroy_plan(p); |
| | | } |
| | | |
| | | void FeatureExtract::reset() |
| | |
| | | return fqueue.size(); |
| | | } |
| | | |
| | | void FeatureExtract::fftw_init() |
| | | void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag) |
| | | { |
| | | int fft_size = 512; |
| | | fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size); |
| | | fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size); |
| | | float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size); |
| | | fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size); |
| | | memset(fft_input, 0, sizeof(float) * fft_size); |
| | | p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE); |
| | | } |
| | | |
| | | void FeatureExtract::insert(float *din, int len, int flag) |
| | | { |
| | | const float *window = (const float *)&window_hex; |
| | | if (mode == 3) |
| | | window = (const float *)&window_hamm_hex; |
| | | |
| | | int window_size = 400; |
| | | int fft_size = 512; |
| | | int window_shift = 160; |
| | | |
| | | speech.load(din, len); |
| | | int i, j; |
| | | float tmp_feature[80]; |
| | | if (mode == 0 || mode == 2 || mode == 3) { |
| | | int ll = (speech.size() - 400) / 160 + 1; |
| | | int ll = (speech.size() - window_size) / window_shift + 1; |
| | | fqueue.reinit(ll); |
| | | } |
| | | |
| | | for (i = 0; i <= speech.size() - 400; i = i + window_shift) { |
| | | for (i = 0; i <= speech.size() - window_size; i = i + window_shift) { |
| | | float tmp_mean = 0; |
| | | for (j = 0; j < window_size; j++) { |
| | | tmp_mean += speech[i + j]; |
| | |
| | | pre_val = cur_val; |
| | | } |
| | | |
| | | fftwf_execute(p); |
| | | fftwf_execute_dft_r2c(plan, fft_input, fft_out); |
| | | |
| | | melspect((float *)fft_out, tmp_feature); |
| | | int tmp_flag = S_MIDDLE; |
| | |
| | | fqueue.push(tmp_feature, tmp_flag); |
| | | } |
| | | speech.update(i); |
| | | fftwf_free(fft_input); |
| | | fftwf_free(fft_out); |
| | | } |
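`FeatureExtract::insert` slides a 400-sample window in 160-sample steps, which corresponds to 25 ms frames every 10 ms if the input is 16 kHz (an assumption based on the models' usual rate), and preallocates `(speech.size() - window_size) / window_shift + 1` frames. Because C++ integer division truncates toward zero, that expression gives 1 rather than 0 for inputs of 241 to 399 samples, assuming `speech.size()` returns a signed `int`, even though the loop then produces no frames. A sketch with an explicit guard; `NumFrames` and `FrameStarts` are hypothetical names:

```cpp
#include <cassert>
#include <vector>

// Number of complete frames, using the formula from FeatureExtract::insert
// but returning 0 when fewer than window_size samples are available.
int NumFrames(int num_samples, int window_size = 400, int window_shift = 160) {
  if (num_samples < window_size) return 0;
  return (num_samples - window_size) / window_shift + 1;
}

// Start offset of every frame the framing loop would visit.
std::vector<int> FrameStarts(int num_samples, int window_size = 400,
                             int window_shift = 160) {
  std::vector<int> starts;
  for (int i = 0; i + window_size <= num_samples; i += window_shift) {
    starts.push_back(i);
  }
  return starts;
}
```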
| | | |
| | | bool FeatureExtract::fetch(Tensor<float> *&dout) |
| | |
| | | void FeatureExtract::melspect(float *din, float *dout) |
| | | { |
| | | float fftmag[256]; |
| | | // float tmp; |
| | | const float *melcoe = (const float *)melcoe_hex; |
| | | int i; |
| | | for (i = 0; i < 256; i++) { |
| | |
| | | SpeechWrap speech; |
| | | FeatureQueue fqueue; |
| | | int mode; |
| | | int fft_size = 512; |
| | | int window_size = 400; |
| | | int window_shift = 160; |
| | | |
| | | float *fft_input; |
| | | fftwf_complex *fft_out; |
| | | fftwf_plan p; |
| | | |
| | | void fftw_init(); |
| | | //void fftw_init(); |
| | | void melspect(float *din, float *dout); |
| | | void global_cmvn(float *din); |
| | | |
| | |
| | | FeatureExtract(int mode); |
| | | ~FeatureExtract(); |
| | | int size(); |
| | | int status(); |
| | | //int status(); |
| | | void reset(); |
| | | void insert(float *din, int len, int flag); |
| | | void insert(fftwf_plan plan, float *din, int len, int flag); |
| | | bool fetch(Tensor<float> *&dout); |
| | | }; |
| | | |
| | |
| | | { |
| | | ifstream in(filename); |
| | | loadVocabFromYaml(filename); |
| | | |
| | | /* |
| | | string line; |
| | | if (in) // the file exists |
| | | { |
| | | while (getline(in, line)) // line does not include the trailing newline |
| | | { |
| | | vocab.push_back(line); |
| | | } |
| | | } |
| | | else{ |
| | | printf("Cannot load vocab from: %s; a vocab.txt file is required\n", filename); |
| | | exit(-1); |
| | | } |
| | | */ |
| | | } |
| | | Vocab::~Vocab() |
| | | { |
| | |
| | | { |
| | | std::string msg; |
| | | float snippet_time; |
| | | }RPASR_RECOG_RESULT; |
| | | }FUNASR_RECOG_RESULT; |
| | | |
| | | |
| | | #ifdef _WIN32 |
| | |
| | | |
| | | } |
| | | } |
| | | } |
| | | } |
| New file |
| | |
| | | #include "precomp.h" |
| | | #ifdef __cplusplus |
| | | |
| | | extern "C" { |
| | | #endif |
| | | |
| | | // APIs for qmasr |
| | | _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize) |
| | | { |
| | | Model* mm = create_model(szModelDir, nThreadNum, quantize); |
| | | return mm; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback) |
| | | { |
| | | Model* pRecogObj = (Model*)handle; |
| | | if (!pRecogObj) |
| | | return nullptr; |
| | | |
| | | int32_t sampling_rate = -1; |
| | | Audio audio(1); |
| | | if (!audio.loadwav(szBuf, nLen, &sampling_rate)) |
| | | return nullptr; |
| | | //audio.split(); |
| | | |
| | | float* buff; |
| | | int len; |
| | | int flag=0; |
| | | FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT; |
| | | pResult->snippet_time = audio.get_time_len(); |
| | | int nStep = 0; |
| | | int nTotal = audio.get_queue_size(); |
| | | while (audio.fetch(buff, len, flag) > 0) { |
| | | //pRecogObj->reset(); |
| | | string msg = pRecogObj->forward(buff, len, flag); |
| | | pResult->msg += msg; |
| | | nStep++; |
| | | if (fnCallback) |
| | | fnCallback(nStep, nTotal); |
| | | } |
| | | |
| | | return pResult; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback) |
| | | { |
| | | Model* pRecogObj = (Model*)handle; |
| | | if (!pRecogObj) |
| | | return nullptr; |
| | | |
| | | Audio audio(1); |
| | | if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate)) |
| | | return nullptr; |
| | | //audio.split(); |
| | | |
| | | float* buff; |
| | | int len; |
| | | int flag = 0; |
| | | FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT; |
| | | pResult->snippet_time = audio.get_time_len(); |
| | | int nStep = 0; |
| | | int nTotal = audio.get_queue_size(); |
| | | while (audio.fetch(buff, len, flag) > 0) { |
| | | //pRecogObj->reset(); |
| | | string msg = pRecogObj->forward(buff, len, flag); |
| | | pResult->msg += msg; |
| | | nStep++; |
| | | if (fnCallback) |
| | | fnCallback(nStep, nTotal); |
| | | } |
| | | |
| | | return pResult; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback) |
| | | { |
| | | Model* pRecogObj = (Model*)handle; |
| | | if (!pRecogObj) |
| | | return nullptr; |
| | | |
| | | Audio audio(1); |
| | | if (!audio.loadpcmwav(szFileName, &sampling_rate)) |
| | | return nullptr; |
| | | //audio.split(); |
| | | |
| | | float* buff; |
| | | int len; |
| | | int flag = 0; |
| | | FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT; |
| | | pResult->snippet_time = audio.get_time_len(); |
| | | int nStep = 0; |
| | | int nTotal = audio.get_queue_size(); |
| | | while (audio.fetch(buff, len, flag) > 0) { |
| | | //pRecogObj->reset(); |
| | | string msg = pRecogObj->forward(buff, len, flag); |
| | | pResult->msg += msg; |
| | | nStep++; |
| | | if (fnCallback) |
| | | fnCallback(nStep, nTotal); |
| | | } |
| | | |
| | | return pResult; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback) |
| | | { |
| | | Model* pRecogObj = (Model*)handle; |
| | | if (!pRecogObj) |
| | | return nullptr; |
| | | |
| | | int32_t sampling_rate = -1; |
| | | Audio audio(1); |
| | | if(!audio.loadwav(szWavfile, &sampling_rate)) |
| | | return nullptr; |
| | | //audio.split(); |
| | | |
| | | float* buff; |
| | | int len; |
| | | int flag = 0; |
| | | int nStep = 0; |
| | | int nTotal = audio.get_queue_size(); |
| | | FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT; |
| | | pResult->snippet_time = audio.get_time_len(); |
| | | while (audio.fetch(buff, len, flag) > 0) { |
| | | //pRecogObj->reset(); |
| | | string msg = pRecogObj->forward(buff, len, flag); |
| | | pResult->msg+= msg; |
| | | nStep++; |
| | | if (fnCallback) |
| | | fnCallback(nStep, nTotal); |
| | | } |
| | | |
| | | return pResult; |
| | | } |
| | | |
| | | _FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result) |
| | | { |
| | | if (!Result) |
| | | return 0; |
| | | |
| | | return 1; |
| | | } |
| | | |
| | | |
| | | _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result) |
| | | { |
| | | if (!Result) |
| | | return 0.0f; |
| | | |
| | | return ((FUNASR_RECOG_RESULT*)Result)->snippet_time; |
| | | } |
| | | |
| | | _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex) |
| | | { |
| | | FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result; |
| | | if(!pResult) |
| | | return nullptr; |
| | | |
| | | return pResult->msg.c_str(); |
| | | } |
| | | |
| | | _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result) |
| | | { |
| | | if (Result) |
| | | { |
| | | delete (FUNASR_RECOG_RESULT*)Result; |
| | | } |
| | | } |
| | | |
| | | _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle) |
| | | { |
| | | Model* pRecogObj = (Model*)handle; |
| | | |
| | | if (!pRecogObj) |
| | | return; |
| | | |
| | | delete pRecogObj; |
| | | } |
| | | |
| | | #ifdef __cplusplus |
| | | |
| | | } |
| | | #endif |
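The exported functions follow an opaque-handle pattern: `FunASRInit` returns a `Model*` as a `FUNASR_HANDLE`, each `FunASRRecog*` call allocates a `FUNASR_RECOG_RESULT` with `new`, and the caller must hand results back to `FunASRFreeResult` (and the model to `FunASRUninit`) instead of freeing them directly. The self-contained mock below reproduces that lifecycle and the `(nStep, nTotal)` progress callback; every `Demo*` name is hypothetical and stands in for the real FunASR types, which need the library and a model directory to run.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the opaque types: callers only see void
// pointers, and the library casts them back internally.
typedef void* DEMO_RESULT;
typedef void (*DEMO_CALLBACK)(int step, int total);

struct DemoResult {  // plays the role of FUNASR_RECOG_RESULT
  std::string msg;
  float snippet_time;
};

// Concatenates per-chunk transcripts and reports progress, like the
// fetch/forward loop in FunASRRecogFile. `chunks` stands in for the audio
// frame queue plus the model's forward() output.
DEMO_RESULT DemoRecog(const std::vector<std::string>& chunks, float seconds,
                      DEMO_CALLBACK cb) {
  DemoResult* r = new DemoResult{"", seconds};
  int total = static_cast<int>(chunks.size());
  int step = 0;
  for (const std::string& text : chunks) {
    r->msg += text;
    ++step;
    if (cb) cb(step, total);
  }
  return r;
}

const char* DemoGetResult(DEMO_RESULT res) {
  return res ? static_cast<DemoResult*>(res)->msg.c_str() : nullptr;
}

// Results are released through the library's own function, never free().
void DemoFreeResult(DEMO_RESULT res) { delete static_cast<DemoResult*>(res); }
```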
| | | |
| | |
| | | using namespace paraformer; |
| | | |
| | | ModelImp::ModelImp(const char* path,int nNumThread, bool quantize) |
| | | { |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{ |
| | | string model_path; |
| | | string cmvn_path; |
| | | string config_path; |
| | |
| | | cmvn_path = pathAppend(path, "am.mvn"); |
| | | config_path = pathAppend(path, "config.yaml"); |
| | | |
| | | fe = new FeatureExtract(3); |
| | | fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size); |
| | | fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size); |
| | | memset(fft_input, 0, sizeof(float) * fft_size); |
| | | plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE); |
| | | |
| | | //sessionOptions.SetInterOpNumThreads(1); |
| | | sessionOptions.SetIntraOpNumThreads(nNumThread); |
| | |
| | | |
| | | #ifdef _WIN32 |
| | | wstring wstrPath = strToWstr(model_path); |
| | | m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions); |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions); |
| | | #else |
| | | m_session = new Ort::Session(env, model_path.c_str(), sessionOptions); |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions); |
| | | #endif |
| | | |
| | | string strName; |
| | | getInputName(m_session, strName); |
| | | getInputName(m_session.get(), strName); |
| | | m_strInputNames.push_back(strName.c_str()); |
| | | getInputName(m_session, strName,1); |
| | | getInputName(m_session.get(), strName,1); |
| | | m_strInputNames.push_back(strName); |
| | | |
| | | getOutputName(m_session, strName); |
| | | getOutputName(m_session.get(), strName); |
| | | m_strOutputNames.push_back(strName); |
| | | getOutputName(m_session, strName,1); |
| | | getOutputName(m_session.get(), strName,1); |
| | | m_strOutputNames.push_back(strName); |
| | | |
| | | for (auto& item : m_strInputNames) |
| | |
| | | |
| | | ModelImp::~ModelImp() |
| | | { |
| | | if(fe) |
| | | delete fe; |
| | | if (m_session) |
| | | { |
| | | delete m_session; |
| | | m_session = nullptr; |
| | | } |
| | | if(vocab) |
| | | delete vocab; |
| | | fftwf_free(fft_input); |
| | | fftwf_free(fft_out); |
| | | fftwf_destroy_plan(plan); |
| | | fftwf_cleanup(); |
| | | } |
| | | |
| | | void ModelImp::reset() |
| | | { |
| | | fe->reset(); |
| | | } |
| | | |
| | | void ModelImp::apply_lfr(Tensor<float>*& din) |
| | |
| | | |
| | | string ModelImp::forward(float* din, int len, int flag) |
| | | { |
| | | |
| | | Tensor<float>* in; |
| | | fe->insert(din, len, flag); |
| | | FeatureExtract* fe = new FeatureExtract(3); |
| | | fe->reset(); |
| | | fe->insert(plan, din, len, flag); |
| | | fe->fetch(in); |
| | | apply_lfr(in); |
| | | apply_cmvn(in); |
| | | Ort::RunOptions run_option; |
| | | |
| | | #ifdef _WIN_X86 |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| | | #else |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); |
| | | #endif |
| | | |
| | | std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] }; |
| | | Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo, |
| | |
| | | auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size()); |
| | | std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); |
| | | |
| | | |
| | | int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>()); |
| | | float* floatData = outputTensor[0].GetTensorMutableData<float>(); |
| | | auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>(); |
| | |
| | | result = ""; |
| | | } |
| | | |
| | | |
| | | if(in) |
| | | if(in){ |
| | | delete in; |
| | | in = nullptr; |
| | | } |
| | | if(fe){ |
| | | delete fe; |
| | | fe = nullptr; |
| | | } |
| | | |
| | | return result; |
| | | } |
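In this revision `forward` builds a fresh `FeatureExtract` on every call and deletes it by hand at the end, while the session moves from a raw `Ort::Session*` to `std::unique_ptr<Ort::Session>`. The same ownership idiom can cover the per-call extractor so that an early `return` cannot leak it. Note also that the local `fe` hides the class member of the same name, assuming that member is still declared in `ModelImp`. A sketch with a hypothetical `Extractor` type that counts live instances:

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for FeatureExtract that counts live instances,
// so a leak would show up as a nonzero count.
struct Extractor {
  static int live;
  explicit Extractor(int mode) : mode(mode) { ++live; }
  ~Extractor() { --live; }
  int mode;
};
int Extractor::live = 0;

// Sketch of a forward() that owns its per-call extractor through
// std::unique_ptr: the object is destroyed on every return path,
// including the early one, without explicit delete calls.
std::string Forward(int len) {
  auto extractor = std::make_unique<Extractor>(3);
  if (len <= 0) return "";  // early exit: extractor is still freed
  return "frames:" + std::to_string(len);
}
```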
| | |
| | | |
| | | class ModelImp : public Model { |
| | | private: |
| | | FeatureExtract* fe; |
| | | int fft_size=512; |
| | | float *fft_input; |
| | | fftwf_complex *fft_out; |
| | | fftwf_plan plan; |
| | | |
| | | Vocab* vocab; |
| | | vector<float> means_list; |
| | |
| | | |
| | | string greedy_search( float* in, int nLen); |
| | | |
| | | #ifdef _WIN_X86 |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| | | #else |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); |
| | | #endif |
| | | |
| | | Ort::Session* m_session = nullptr; |
| | | Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer"); |
| | | Ort::SessionOptions sessionOptions = Ort::SessionOptions(); |
| | | std::unique_ptr<Ort::Session> m_session; |
| | | Ort::Env env_; |
| | | Ort::SessionOptions sessionOptions; |
| | | |
| | | vector<string> m_strInputNames, m_strOutputNames; |
| | | vector<const char*> m_szInputNames; |
| | | vector<const char*> m_szOutputNames; |
| | | //string m_strInputName, m_strInputNameLen; |
| | | //string m_strOutputName, m_strOutputNameLen; |
| | | |
| | | public: |
| | | ModelImp(const char* path, int nNumThread=0, bool quantize=false); |
| | |
| | | #include "FeatureQueue.h" |
| | | #include "SpeechWrap.h" |
| | | #include <Audio.h> |
| | | #include "resample.h" |
| | | #include "Model.h" |
| | | #include "paraformer_onnx.h" |
| | | #include "librapidasrapi.h" |
| | | #include "libfunasrapi.h" |
| | | |
| | | |
| | | using namespace paraformer; |
| New file |
| | |
| | | /** |
| | | * Copyright 2013 Pegah Ghahremani |
| | | * 2014 IMSL, PKU-HKUST (author: Wei Shi) |
| | | * 2014 Yanqing Sun, Junjie Wang |
| | | * 2014 Johns Hopkins University (author: Daniel Povey) |
| | | * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) |
| | | * |
| | | * See LICENSE for clarification regarding multiple authors |
| | | * |
| | | * Licensed under the Apache License, Version 2.0 (the "License"); |
| | | * you may not use this file except in compliance with the License. |
| | | * You may obtain a copy of the License at |
| | | * |
| | | * http://www.apache.org/licenses/LICENSE-2.0 |
| | | * |
| | | * Unless required by applicable law or agreed to in writing, software |
| | | * distributed under the License is distributed on an "AS IS" BASIS, |
| | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | | * See the License for the specific language governing permissions and |
| | | * limitations under the License. |
| | | */ |
| | | // this file is copied and modified from |
| | | // kaldi/src/feat/resample.cc |
| | | |
| | | #include "resample.h" |
| | | |
| | | #include <assert.h> |
| | | #include <math.h> |
| | | #include <stdio.h> |
| | | |
| | | #include <cstdlib> |
| | | #include <type_traits> |
| | | |
| | | #ifndef M_2PI |
| | | #define M_2PI 6.283185307179586476925286766559005 |
| | | #endif |
| | | |
| | | #ifndef M_PI |
| | | #define M_PI 3.1415926535897932384626433832795 |
| | | #endif |
| | | |
| | | template <class I> |
| | | I Gcd(I m, I n) { |
| | | // this function is copied from kaldi/src/base/kaldi-math.h |
| | | if (m == 0 || n == 0) { |
| | | if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. |
| | | fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n"); |
| | | exit(-1); |
| | | } |
| | | return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); |
| | | // return absolute value of whichever is nonzero |
| | | } |
| | | // could use compile-time assertion |
| | | // but involves messing with complex template stuff. |
| | | static_assert(std::is_integral<I>::value, ""); |
| | | while (1) { |
| | | m %= n; |
| | | if (m == 0) return (n > 0 ? n : -n); |
| | | n %= m; |
| | | if (n == 0) return (m > 0 ? m : -m); |
| | | } |
| | | } |
| | | |
| | | /// Returns the least common multiple of two integers. Will |
| | | /// crash unless the inputs are positive. |
| | | template <class I> |
| | | I Lcm(I m, I n) { |
| | | // This function is copied from kaldi/src/base/kaldi-math.h |
| | | assert(m > 0 && n > 0); |
| | | I gcd = Gcd(m, n); |
| | | return gcd * (m / gcd) * (n / gcd); |
| | | } |
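The constructor divides both rates by their GCD to find the repeating unit, and `GetNumOutputSamples` counts time in ticks of `1 / Lcm(in, out)` seconds. For 44.1 kHz to 16 kHz, the GCD is 100, so each unit maps 441 input samples to 160 output samples, and the tick frequency is 7,056,000 Hz, which still fits in the `int32_t` the code uses. A sketch of the same arithmetic using C++17 `std::gcd` and `std::lcm` from `<numeric>` instead of the Kaldi helpers; `ComputeUnit` is a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>

// Repeating-unit sizes as computed in the LinearResample constructor.
struct ResampleUnit {
  int32_t input_samples;   // input_samples_in_unit_
  int32_t output_samples;  // output_samples_in_unit_
  int64_t tick_freq;       // Lcm(in, out), as used by GetNumOutputSamples
};

ResampleUnit ComputeUnit(int32_t rate_in, int32_t rate_out) {
  int32_t base = std::gcd(rate_in, rate_out);
  return {rate_in / base, rate_out / base,
          std::lcm(static_cast<int64_t>(rate_in), static_cast<int64_t>(rate_out))};
}
```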
| | | |
| | | static float DotProduct(const float *a, const float *b, int32_t n) { |
| | | float sum = 0; |
| | | for (int32_t i = 0; i != n; ++i) { |
| | | sum += a[i] * b[i]; |
| | | } |
| | | return sum; |
| | | } |
| | | |
| | | LinearResample::LinearResample(int32_t samp_rate_in_hz, |
| | | int32_t samp_rate_out_hz, float filter_cutoff_hz, |
| | | int32_t num_zeros) |
| | | : samp_rate_in_(samp_rate_in_hz), |
| | | samp_rate_out_(samp_rate_out_hz), |
| | | filter_cutoff_(filter_cutoff_hz), |
| | | num_zeros_(num_zeros) { |
| | | assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 && |
| | | filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz && |
| | | filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0); |
| | | |
| | | // base_freq is the frequency of the repeating unit, which is the gcd |
| | | // of the input frequencies. |
| | | int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_); |
| | | input_samples_in_unit_ = samp_rate_in_ / base_freq; |
| | | output_samples_in_unit_ = samp_rate_out_ / base_freq; |
| | | |
| | | SetIndexesAndWeights(); |
| | | Reset(); |
| | | } |
| | | |
| | | void LinearResample::SetIndexesAndWeights() { |
| | | first_index_.resize(output_samples_in_unit_); |
| | | weights_.resize(output_samples_in_unit_); |
| | | |
| | | double window_width = num_zeros_ / (2.0 * filter_cutoff_); |
| | | |
| | | for (int32_t i = 0; i < output_samples_in_unit_; i++) { |
| | | double output_t = i / static_cast<double>(samp_rate_out_); |
| | | double min_t = output_t - window_width, max_t = output_t + window_width; |
| | | // we do ceil on the min and floor on the max, because if we did it |
| | | // the other way around we would unnecessarily include indexes just |
| | | // outside the window, with zero coefficients. It's possible |
| | | // if the arguments to the ceil and floor expressions are integers |
| | | // (e.g. if filter_cutoff_ has an exact ratio with the sample rates), |
| | | // that we unnecessarily include something with a zero coefficient, |
| | | // but this is only a slight efficiency issue. |
| | | int32_t min_input_index = ceil(min_t * samp_rate_in_), |
| | | max_input_index = floor(max_t * samp_rate_in_), |
| | | num_indices = max_input_index - min_input_index + 1; |
| | | first_index_[i] = min_input_index; |
| | | weights_[i].resize(num_indices); |
| | | for (int32_t j = 0; j < num_indices; j++) { |
| | | int32_t input_index = min_input_index + j; |
| | | double input_t = input_index / static_cast<double>(samp_rate_in_), |
| | | delta_t = input_t - output_t; |
| | | // sign of delta_t doesn't matter. |
| | | weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_; |
| | | } |
| | | } |
| | | } |
| | | |
| | | /** Here, t is a time in seconds representing an offset from |
| | | the center of the windowed filter function, and FilterFunc(t) |
| | | returns the windowed filter function, described |
| | | in the header as h(t) = f(t)g(t), evaluated at t. |
| | | */ |
| | | float LinearResample::FilterFunc(float t) const { |
| | | float window, // raised-cosine (Hanning) window, nonzero only for |
| | | // abs(t) < num_zeros_/(2*filter_cutoff_) |
| | | filter; // sinc filter function |
| | | if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) |
| | | window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); |
| | | else |
| | | window = 0.0; // outside support of window function |
| | | if (t != 0) |
| | | filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); |
| | | else |
| | | filter = 2 * filter_cutoff_; // limit of the function at t = 0 |
| | | return filter * window; |
| | | } |
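`FilterFunc` multiplies an ideal low-pass kernel `sin(2*pi*f_c*t) / (pi*t)` by a Hann window that reaches zero at `abs(t) = num_zeros_ / (2*f_c)`. Since the sinc has zeros at multiples of `1 / (2*f_c)`, the kernel spans `num_zeros_` zero crossings on each side of its center. A standalone sketch of the same function, useful for checking values by hand; `WindowedSinc` is a hypothetical name and `cutoff_hz = 3960` (0.99 of half of 8 kHz) is only an example value:

```cpp
#include <cassert>
#include <cmath>

// Standalone copy of h(t) = f(t) g(t) from LinearResample::FilterFunc:
// a sinc low-pass with cutoff cutoff_hz times a raised-cosine (Hann)
// window that is nonzero only for |t| < num_zeros / (2 * cutoff_hz).
double WindowedSinc(double t, double cutoff_hz, int num_zeros) {
  const double kPi = 3.14159265358979323846;
  double half_width = num_zeros / (2.0 * cutoff_hz);
  double window =
      std::fabs(t) < half_width
          ? 0.5 * (1.0 + std::cos(2.0 * kPi * cutoff_hz / num_zeros * t))
          : 0.0;
  double filter = (t != 0.0) ? std::sin(2.0 * kPi * cutoff_hz * t) / (kPi * t)
                             : 2.0 * cutoff_hz;  // limit as t -> 0
  return filter * window;
}
```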
| | | |
| | | void LinearResample::Reset() { |
| | | input_sample_offset_ = 0; |
| | | output_sample_offset_ = 0; |
| | | input_remainder_.resize(0); |
| | | } |
| | | |
| | | void LinearResample::Resample(const float *input, int32_t input_dim, bool flush, |
| | | std::vector<float> *output) { |
| | | int64_t tot_input_samp = input_sample_offset_ + input_dim, |
| | | tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); |
| | | |
| | | assert(tot_output_samp >= output_sample_offset_); |
| | | |
| | | output->resize(tot_output_samp - output_sample_offset_); |
| | | |
| | | // samp_out is the index into the total output signal, not just the part |
| | | // of it we are producing here. |
| | | for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp; |
| | | samp_out++) { |
| | | int64_t first_samp_in; |
| | | int32_t samp_out_wrapped; |
| | | GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); |
| | | const std::vector<float> &weights = weights_[samp_out_wrapped]; |
| | | // first_input_index is the first index into "input" that we have a weight |
| | | // for. |
| | | int32_t first_input_index = |
| | | static_cast<int32_t>(first_samp_in - input_sample_offset_); |
| | | float this_output; |
| | | if (first_input_index >= 0 && |
| | | first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) { |
| | | this_output = |
| | | DotProduct(input + first_input_index, weights.data(), weights.size()); |
| | | } else { // Handle edge cases. |
| | | this_output = 0.0; |
| | | for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) { |
| | | float weight = weights[i]; |
| | | int32_t input_index = first_input_index + i; |
| | | if (input_index < 0 && |
| | | static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) { |
| | | this_output += |
| | | weight * input_remainder_[input_remainder_.size() + input_index]; |
| | | } else if (input_index >= 0 && input_index < input_dim) { |
| | | this_output += weight * input[input_index]; |
| | | } else if (input_index >= input_dim) { |
| | | // We're past the end of the input and are adding zero; should only |
| | | // happen if the user specified flush == true, or else we would not |
| | | // be trying to output this sample. |
| | | assert(flush); |
| | | } |
| | | } |
| | | } |
| | | int32_t output_index = |
| | | static_cast<int32_t>(samp_out - output_sample_offset_); |
| | | (*output)[output_index] = this_output; |
| | | } |
| | | |
| | | if (flush) { |
| | | Reset(); // Reset the internal state. |
| | | } else { |
| | | SetRemainder(input, input_dim); |
| | | input_sample_offset_ = tot_input_samp; |
| | | output_sample_offset_ = tot_output_samp; |
| | | } |
| | | } |
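`Resample` relies on `GetNumOutputSamples` to decide how much output a chunk can produce. When `flush` is false, output samples within one filter half-width of the end are held back, because their kernels would need input that has not arrived yet. The sketch below restates that computation; its final steps (take the last output index in the closed interval, step back one if it lands exactly on the end, then add one) are assumed to match Kaldi's `resample.cc`, since the function is cut off in this excerpt. `NumOutputSamples` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>

// Output length for input_num_samp samples. The steps up to the division
// mirror GetNumOutputSamples; the open-interval adjustment is assumed to
// follow Kaldi's resample.cc.
int64_t NumOutputSamples(int64_t input_num_samp, int32_t rate_in, int32_t rate_out,
                         double cutoff_hz, int num_zeros, bool flush) {
  int64_t tick_freq =
      std::lcm(static_cast<int64_t>(rate_in), static_cast<int64_t>(rate_out));
  int64_t ticks_per_input_period = tick_freq / rate_in;
  int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
  if (!flush) {
    // Hold back the filter half-width, rounded down to whole ticks.
    double window_width = num_zeros / (2.0 * cutoff_hz);
    interval_length_in_ticks -=
        static_cast<int64_t>(std::floor(window_width * tick_freq));
  }
  if (interval_length_in_ticks <= 0) return 0;
  int64_t ticks_per_output_period = tick_freq / rate_out;
  // Last output sample in the closed interval [0, T] ...
  int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
  // ... stepped back by one if it falls exactly on T, since [0, T) is open.
  if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) {
    --last_output_samp;
  }
  return last_output_samp + 1;  // output indices start at 0
}
```

For example, 160 input samples resampled from 16 kHz to 8 kHz with `cutoff_hz = 3960` and `num_zeros = 6` yield 80 samples when flushing but only 74 otherwise; the held-back samples are emitted once more input arrives.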
| | | |
| | | int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp, |
| | | bool flush) const { |
| | | // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, |
| | | // where tick_freq is the least common multiple of samp_rate_in_ and |
| | | // samp_rate_out_. |
| | | int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_); |
| | | int32_t ticks_per_input_period = tick_freq / samp_rate_in_; |
| | | |
| | | // work out the number of ticks in the time interval |
| | | // [ 0, input_num_samp/samp_rate_in_ ). |
| | | int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period; |
| | | if (!flush) { |
| | | float window_width = num_zeros_ / (2.0 * filter_cutoff_); |
| | | // To count the window-width in ticks we take the floor. This |
| | | // is because we're looking for the largest integer num-out-samp |
| | | // that fits in the interval, which is open on the right, a reduction |
| | | // in interval length of less than a tick will never make a difference. |
| | | // For example, the largest integer in the interval [ 0, 2 ) and the |
| | | // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one). |
| | | // So when we're subtracting the window-width we can ignore the fractional |
| | | // part. |
| | | int32_t window_width_ticks = floor(window_width * tick_freq); |
| | | // The time-period of the output that we can sample gets reduced |
| | | // by the window-width (which is actually the distance from the |
| | | // center to the edge of the windowing function) if we're not |
| | | // "flushing the output". |
| | | interval_length_in_ticks -= window_width_ticks; |
| | | } |
| | | if (interval_length_in_ticks <= 0) return 0; |
| | | |
| | | int32_t ticks_per_output_period = tick_freq / samp_rate_out_; |
| | | // Get the last output-sample in the closed interval, i.e. replacing [ ) with |
| | | // [ ]. Note: integer division rounds down. See |
| | | // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of |
| | | // the notation. |
| | | int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period; |
| | | // We need the last output-sample in the open interval, so if it takes us to |
| | | // the end of the interval exactly, subtract one. |
| | | if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) |
| | | last_output_samp--; |
| | | |
| | | // First output-sample index is zero, so the number of output samples |
| | | // is the last output-sample plus one. |
| | | int64_t num_output_samp = last_output_samp + 1; |
| | | return num_output_samp; |
| | | } |
| | | |
| | | // inline |
| | | void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in, |
| | | int32_t *samp_out_wrapped) const { |
| | | // A unit is the smallest nonzero amount of time that is an exact |
| | | // multiple of the input and output sample periods. The unit index |
| | | // is the answer to "which numbered unit we are in". |
| | | int64_t unit_index = samp_out / output_samples_in_unit_; |
| | | // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ |
| | | *samp_out_wrapped = |
| | | static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_); |
| | | *first_samp_in = |
| | | first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_; |
| | | } |
| | | |
| | | void LinearResample::SetRemainder(const float *input, int32_t input_dim) { |
| | | std::vector<float> old_remainder(input_remainder_); |
| | | // max_remainder_needed is the width of the filter from side to side, |
| | | // measured in input samples. You might think it should be half that, |
| | | // but you have to consider that you might be wanting to output samples |
| | | // that are "in the past" relative to the beginning of the latest |
| | | // input... anyway, storing more remainder than needed is not harmful. |
| | | int32_t max_remainder_needed = |
| | | ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_); |
| | | input_remainder_.resize(max_remainder_needed); |
| | | for (int32_t index = -static_cast<int32_t>(input_remainder_.size()); |
| | | index < 0; index++) { |
| | | // we interpret "index" as an offset from the end of "input" and |
| | | // from the end of input_remainder_. |
| | | int32_t input_index = index + input_dim; |
| | | if (input_index >= 0) { |
| | | input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] = |
| | | input[input_index]; |
| | | } else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) { |
| | | input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] = |
| | | old_remainder[input_index + |
| | | static_cast<int32_t>(old_remainder.size())]; |
| | | // else leave it at zero. |
| | | } |
| | | } |
| | | } |
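The tick arithmetic in `GetNumOutputSamples` above is easy to check outside C++. Below is a minimal Python sketch that mirrors that function step by step. The function and parameter names are illustrative and do not belong to FunASR or Kaldi. The 3960 Hz cutoff in the example is a hypothetical value chosen only for illustration.

```python
from math import floor, gcd


def lcm(a: int, b: int) -> int:
    """Least common multiple of two positive integers."""
    return a * b // gcd(a, b)


def num_output_samples(input_num_samp: int, samp_rate_in: int,
                       samp_rate_out: int, filter_cutoff: float,
                       num_zeros: int, flush: bool) -> int:
    # Measure time in integer "ticks" of 1 / tick_freq seconds, so that both
    # the input and the output sample periods are a whole number of ticks.
    tick_freq = lcm(samp_rate_in, samp_rate_out)
    ticks_per_input_period = tick_freq // samp_rate_in
    # Length of the half-open interval [0, input_num_samp / samp_rate_in).
    interval_length_in_ticks = input_num_samp * ticks_per_input_period
    if not flush:
        # Without flushing, outputs within window_width seconds of the end
        # still need future input, so the interval shrinks by that much.
        window_width = num_zeros / (2.0 * filter_cutoff)
        interval_length_in_ticks -= floor(window_width * tick_freq)
    if interval_length_in_ticks <= 0:
        return 0
    ticks_per_output_period = tick_freq // samp_rate_out
    # Last output sample in the closed interval (integer division rounds down);
    # step back one if it falls exactly on the excluded right end.
    last_output_samp = interval_length_in_ticks // ticks_per_output_period
    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
        last_output_samp -= 1
    # Output indices start at zero, so the count is the last index plus one.
    return last_output_samp + 1


print(num_output_samples(16000, 16000, 8000, 3960.0, 6, flush=True))   # 8000
print(num_output_samples(16000, 16000, 8000, 3960.0, 6, flush=False))  # 7994
```

The example resamples 16000 input samples from 16 kHz to 8 kHz with `num_zeros=6`. When not flushing, 12 ticks (0.75 ms at the 16 kHz tick rate) are held back, so 6 fewer output samples are produced until more input arrives.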
| New file |
| | |
| | | /** |
| | | * Copyright 2013 Pegah Ghahremani |
| | | * 2014 IMSL, PKU-HKUST (author: Wei Shi) |
| | | * 2014 Yanqing Sun, Junjie Wang |
| | | * 2014 Johns Hopkins University (author: Daniel Povey) |
| | | * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) |
| | | * |
| | | * See LICENSE for clarification regarding multiple authors |
| | | * |
| | | * Licensed under the Apache License, Version 2.0 (the "License"); |
| | | * you may not use this file except in compliance with the License. |
| | | * You may obtain a copy of the License at |
| | | * |
| | | * http://www.apache.org/licenses/LICENSE-2.0 |
| | | * |
| | | * Unless required by applicable law or agreed to in writing, software |
| | | * distributed under the License is distributed on an "AS IS" BASIS, |
| | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | | * See the License for the specific language governing permissions and |
| | | * limitations under the License. |
| | | */ |
| | | // this file is copied and modified from |
| | | // kaldi/src/feat/resample.h |
| | | |
| | | #include <cstdint> |
| | | #include <vector> |
| | | |
| | | |
| | | /* |
| | | We require that the input and output sampling rate be specified as |
| | | integers, as this is an easy way to specify that their ratio be rational. |
| | | */ |
| | | |
| | | class LinearResample { |
| | | public: |
| | | /// Constructor. We make the input and output sample rates integers, because |
| | | /// we are going to need to find a common divisor. This should just remind |
| | | /// you that they need to be integers. The filter cutoff needs to be less |
| | | /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros |
| | | /// controls the sharpness of the filter, more == sharper but less efficient. |
| | | /// We suggest around 4 to 10 for normal use. |
| | | LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz, |
| | | float filter_cutoff_hz, int32_t num_zeros); |
| | | |
| | | /// Calling Reset() resets the state of the object prior to |
| | | /// processing a new signal; it is only necessary if you have called |
| | | /// Resample(x, x_size, false, y) for some signal, which leaves a remainder |
| | | /// of that signal stored in the object, but then abandoned processing the |
| | | /// signal before calling Resample(x, x_size, true, y) for the last piece. |
| | | /// Calling it unnecessarily between signals does no harm. |
| | | void Reset(); |
| | | |
| | | /// This function does the resampling. If you call it with flush == true and |
| | | /// you have never called it with flush == false, it just resamples the input |
| | | /// signal (it resizes the output to a suitable number of samples). |
| | | /// |
| | | /// You can also use this function to process a signal a piece at a time. |
| | | /// Suppose you break it into piece1, piece2, ... pieceN. You can call |
| | | /// \code{.cc} |
| | | /// Resample(piece1, piece1_size, false, &output1); |
| | | /// Resample(piece2, piece2_size, false, &output2); |
| | | /// Resample(piece3, piece3_size, true, &output3); |
| | | /// \endcode |
| | | /// If you call it with flush == false, it won't output the last few samples |
| | | /// but will remember them, so that if you later give it a second piece of |
| | | /// the input signal it can process it correctly. |
| | | /// If your most recent call to the object was with flush == false, it will |
| | | /// have internal state; you can remove this by calling Reset(). |
| | | /// Empty input is acceptable. |
| | | void Resample(const float *input, int32_t input_dim, bool flush, |
| | | std::vector<float> *output); |
| | | |
| | | /// Return the input and output sampling rates (for checks, for example) |
| | | int32_t GetInputSamplingRate() const { return samp_rate_in_; } |
| | | int32_t GetOutputSamplingRate() const { return samp_rate_out_; } |
| | | |
| | | private: |
| | | void SetIndexesAndWeights(); |
| | | |
| | | float FilterFunc(float) const; |
| | | |
| | | /// This function outputs the number of output samples we will output |
| | | /// for a signal with "input_num_samp" input samples. If flush == true, |
| | | /// we return the largest n such that |
| | | /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), |
| | | /// and note that the interval is half-open. If flush == false, |
| | | /// define window_width as num_zeros / (2.0 * filter_cutoff_); |
| | | /// we return the largest n such that (n/samp_rate_out_) is in the interval |
| | | /// [ 0, input_num_samp/samp_rate_in_ - window_width ). |
| | | int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const; |
| | | |
| | | /// Given an output-sample index, this function outputs to *first_samp_in the |
| | | /// first input-sample index that we have a weight on (may be negative), |
| | | /// and to *samp_out_wrapped the index into weights_ where we can get the |
| | | /// corresponding weights on the input. |
| | | inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in, |
| | | int32_t *samp_out_wrapped) const; |
| | | |
| | | void SetRemainder(const float *input, int32_t input_dim); |
| | | |
| | | private: |
| | | // The following variables are provided by the user. |
| | | int32_t samp_rate_in_; |
| | | int32_t samp_rate_out_; |
| | | float filter_cutoff_; |
| | | int32_t num_zeros_; |
| | | |
| | | int32_t input_samples_in_unit_; ///< The number of input samples in the |
| | | ///< smallest repeating unit: |
| | | ///< input_samples_in_unit_ = samp_rate_in_hz / |
| | | ///< Gcd(samp_rate_in_hz, samp_rate_out_hz) |
| | | |
| | | int32_t output_samples_in_unit_; ///< The number of output samples in the |
| | | ///< smallest repeating unit: |
| | | ///< output_samples_in_unit_ = samp_rate_out_hz / |
| | | ///< Gcd(samp_rate_in_hz, samp_rate_out_hz) |
| | | |
| | | /// The first input-sample index that we sum over, for this output-sample |
| | | /// index. May be negative; any truncation at the beginning is handled |
| | | /// separately. This is just for the first few output samples, but we can |
| | | /// extrapolate the correct input-sample index for arbitrary output samples. |
| | | std::vector<int32_t> first_index_; |
| | | |
| | | /// Weights on the input samples, for this output-sample index. |
| | | std::vector<std::vector<float>> weights_; |
| | | |
| | | // the following variables keep track of where we are in a particular signal, |
| | | // if it is being provided over multiple calls to Resample(). |
| | | |
| | | int64_t input_sample_offset_; ///< The number of input samples we have |
| | | ///< already received for this signal |
| | | ///< (including anything in input_remainder_) |
| | | int64_t output_sample_offset_; ///< The number of samples we have already |
| | | ///< output for this signal. |
| | | std::vector<float> input_remainder_; ///< A small trailing part of the |
| | | ///< previously seen input signal. |
| | | }; |
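`GetIndexes` relies on the filter weights repeating every `output_samples_in_unit_` output samples, so only one unit's worth of `first_index_` and `weights_` needs to be stored. The following is a minimal Python sketch of that index arithmetic. It assumes a hypothetical `first_index` table; in the real class that table is filled in by `SetIndexesAndWeights()`, which is not shown here.

```python
from math import gcd
from typing import List, Tuple


def unit_sizes(samp_rate_in: int, samp_rate_out: int) -> Tuple[int, int]:
    """Input and output samples in the smallest repeating unit."""
    g = gcd(samp_rate_in, samp_rate_out)
    return samp_rate_in // g, samp_rate_out // g


def get_indexes(samp_out: int, first_index: List[int],
                input_samples_in_unit: int,
                output_samples_in_unit: int) -> Tuple[int, int]:
    """Return (first_samp_in, samp_out_wrapped) for output sample samp_out."""
    # Which repeating unit this output sample falls in.
    unit_index = samp_out // output_samples_in_unit
    # Position inside the unit (samp_out % output_samples_in_unit for
    # non-negative samp_out); it selects the precomputed weights row.
    samp_out_wrapped = samp_out - unit_index * output_samples_in_unit
    # Shift the per-unit first input index forward by whole units.
    first_samp_in = (first_index[samp_out_wrapped]
                     + unit_index * input_samples_in_unit)
    return first_samp_in, samp_out_wrapped


in_unit, out_unit = unit_sizes(16000, 24000)           # (2, 3)
first_index = [-3, -2, -2]  # hypothetical table, one entry per wrapped index
print(get_indexes(7, first_index, in_unit, out_unit))  # (2, 1)
```

For 16 kHz to 24 kHz the unit spans 2 input samples and 3 output samples (the gcd is 8000). With the made-up table, output sample 7 lies in unit 2 at wrapped position 1, so its first input index is -2 + 2 * 2 = 2.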
| | |
| | | endif() |
| | | endif() |
| | | |
| | | set(EXTRA_LIBS rapidasr) |
| | | set(EXTRA_LIBS funasr) |
| | | |
| | | |
| | | include_directories(${CMAKE_SOURCE_DIR}/include) |
| | |
| | | #include <win_func.h> |
| | | #endif |
| | | |
| | | #include "librapidasrapi.h" |
| | | #include "libfunasrapi.h" |
| | | |
| | | #include <iostream> |
| | | #include <fstream> |
| | |
| | | // whether to use the quantized model; argv[3] is parsed as true/false |
| | | bool quantize = false; |
| | | istringstream(argv[3]) >> boolalpha >> quantize; |
| | | RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize); |
| | | FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize); |
| | | |
| | | if (!AsrHanlde) |
| | | { |
| | |
| | | gettimeofday(&start, NULL); |
| | | float snippet_time = 0.0f; |
| | | |
| | | RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL); |
| | | FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL); |
| | | |
| | | gettimeofday(&end, NULL); |
| | | |
| | | if (Result) |
| | | { |
| | | string msg = RapidAsrGetResult(Result, 0); |
| | | string msg = FunASRGetResult(Result, 0); |
| | | setbuf(stdout, NULL); |
| | | cout << "Result: \""; |
| | | cout << msg << "\"." << endl; |
| | | snippet_time = RapidAsrGetRetSnippetTime(Result); |
| | | RapidAsrFreeResult(Result); |
| | | printf("Result: %s \n", msg.c_str()); |
| | | snippet_time = FunASRGetRetSnippetTime(Result); |
| | | FunASRFreeResult(Result); |
| | | } |
| | | else |
| | | { |
| | | cout <<"no return data!"; |
| | | } |
| | | |
| | | //char* buff = nullptr; |
| | | //int len = 0; |
| | | //ifstream ifs(argv[2], std::ios::binary | std::ios::in); |
| | | //if (ifs.is_open()) |
| | | //{ |
| | | // ifs.seekg(0, std::ios::end); |
| | | // len = ifs.tellg(); |
| | | // ifs.seekg(0, std::ios::beg); |
| | | |
| | | // buff = new char[len]; |
| | | |
| | | // ifs.read(buff, len); |
| | | |
| | | |
| | | // //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL); |
| | | |
| | | // RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL); |
| | | // //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL); |
| | | // gettimeofday(&end, NULL); |
| | | // |
| | | // if (Result) |
| | | // { |
| | | // string msg = RapidAsrGetResult(Result, 0); |
| | | // setbuf(stdout, NULL); |
| | | // cout << "Result: \""; |
| | | // cout << msg << endl; |
| | | // cout << "\"." << endl; |
| | | // snippet_time = RapidAsrGetRetSnippetTime(Result); |
| | | // RapidAsrFreeResult(Result); |
| | | // } |
| | | // else |
| | | // { |
| | | // cout <<"no return data!"; |
| | | // } |
| | | |
| | | // |
| | | //delete[]buff; |
| | | //} |
| | | |
| | | printf("Audio length %lfs.\n", (double)snippet_time); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | |
| | | printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000); |
| | | printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000)); |
| | | |
| | | RapidAsrUninit(AsrHanlde); |
| | | FunASRUninit(AsrHanlde); |
| | | |
| | | return 0; |
| | | } |
| | | |
| | | |
| | | |
| | |
| | | #include <win_func.h> |
| | | #endif |
| | | |
| | | #include "librapidasrapi.h" |
| | | #include "libfunasrapi.h" |
| | | |
| | | #include <iostream> |
| | | #include <fstream> |
| | |
| | | bool quantize = false; |
| | | istringstream(argv[3]) >> boolalpha >> quantize; |
| | | |
| | | RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize); |
| | | FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize); |
| | | if (!AsrHanlde) |
| | | { |
| | | printf("Cannot load ASR model from: %s; the directory must contain model.onnx and vocab.txt\n", argv[1]); |
| | |
| | | // warm up |
| | | for (size_t i = 0; i < 30; i++) |
| | | { |
| | | RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL); |
| | | FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL); |
| | | } |
| | | |
| | | // forward |
| | |
| | | for (size_t i = 0; i < wav_list.size(); i++) |
| | | { |
| | | gettimeofday(&start, NULL); |
| | | RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL); |
| | | FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL); |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | | long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | | total_time += taking_micros; |
| | | |
| | | if(Result){ |
| | | string msg = RapidAsrGetResult(Result, 0); |
| | | printf("Result: %s \n", msg); |
| | | string msg = FunASRGetResult(Result, 0); |
| | | printf("Result: %s \n", msg.c_str()); |
| | | |
| | | snippet_time = RapidAsrGetRetSnippetTime(Result); |
| | | snippet_time = FunASRGetRetSnippetTime(Result); |
| | | total_length += snippet_time; |
| | | RapidAsrFreeResult(Result); |
| | | FunASRFreeResult(Result); |
| | | }else{ |
| | | cout <<"No return data!"; |
| | | } |
| | |
| | | printf("total_time_compute %ld ms.\n", total_time / 1000); |
| | | printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000)); |
| | | |
| | | RapidAsrUninit(AsrHanlde); |
| | | FunASRUninit(AsrHanlde); |
| | | return 0; |
| | | } |
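The benchmark loop above first makes 30 untimed warm-up calls. It then times each `FunASRRecogFile` call with `gettimeofday()` and reports the real-time factor (RTF) as total compute time divided by total audio duration. The same arithmetic can be written as a small Python sketch. The function names are hypothetical, and the `(tv_sec, tv_usec)` tuples stand in for `struct timeval`.

```python
from typing import Iterable, Tuple

Timeval = Tuple[int, int]  # (tv_sec, tv_usec), mirroring struct timeval


def elapsed_micros(start: Timeval, end: Timeval) -> int:
    """Microseconds between two timestamps, as the C++ tester computes it."""
    seconds = end[0] - start[0]
    return seconds * 1_000_000 + end[1] - start[1]


def total_rtf(runs: Iterable[Tuple[Timeval, Timeval, float]]) -> float:
    """RTF over many files: summed compute time / summed audio duration."""
    total_time = 0      # microseconds spent decoding
    total_length = 0.0  # seconds of audio decoded
    for start, end, audio_seconds in runs:
        total_time += elapsed_micros(start, end)
        total_length += audio_seconds
    return total_time / (total_length * 1_000_000)


runs = [((10, 900_000), (12, 100_000), 6.0),  # 1.2 s to decode 6 s of audio
        ((0, 0), (0, 800_000), 4.0)]          # 0.8 s to decode 4 s of audio
print(total_rtf(runs))  # 0.2
```

An RTF below 1.0 means decoding runs faster than real time. Summing before dividing weights each file by its duration, which is not the same as averaging the per-file RTFs.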
| New file |
| | |
| | | import grpc |
| | | import json |
| | | import time |
| | | import asyncio |
| | | import soundfile as sf |
| | | import argparse |
| | | |
| | | from grpc_client import transcribe_audio_bytes |
| | | from paraformer_pb2_grpc import ASRStub |
| | | |
| | | # send the audio data once |
| | | async def grpc_rec(wav_scp, grpc_uri, asr_user, language): |
| | | with grpc.insecure_channel(grpc_uri) as channel: |
| | | stub = ASRStub(channel) |
| | | for line in wav_scp: |
| | | wav_file = line.split()[1] |
| | | wav, _ = sf.read(wav_file, dtype='int16') |
| | | |
| | | b = time.time() |
| | | response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False) |
| | | resp = response.next() |
| | | text = '' |
| | | if 'decoding' == resp.action: |
| | | resp = response.next() |
| | | if 'finish' == resp.action: |
| | | text = json.loads(resp.sentence)['text'] |
| | | response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True) |
| | | res= {'text': text, 'time': time.time() - b} |
| | | print(res) |
| | | |
| | | async def test(args): |
| | | with open(args.wav_scp, "r") as f: |
| | | wav_scp = f.readlines() |
| | | uri = '{}:{}'.format(args.host, args.port) |
| | | await grpc_rec(wav_scp, uri, args.user_allowed, language='zh-CN') |
| | | |
| | | if __name__ == '__main__': |
| | | parser = argparse.ArgumentParser() |
| | | parser.add_argument("--host", |
| | | type=str, |
| | | default="127.0.0.1", |
| | | required=False, |
| | | help="grpc server host ip") |
| | | parser.add_argument("--port", |
| | | type=int, |
| | | default=10108, |
| | | required=False, |
| | | help="grpc server port") |
| | | parser.add_argument("--user_allowed", |
| | | type=str, |
| | | default="project1_user1", |
| | | help="allowed user for grpc client") |
| | | parser.add_argument("--sample_rate", |
| | | type=int, |
| | | default=16000, |
| | | help="audio sample_rate from client") |
| | | parser.add_argument("--wav_scp", |
| | | type=str, |
| | | required=True, |
| | | help="audio wav scp") |
| | | args = parser.parse_args() |
| | | |
| | | asyncio.run(test(args)) |
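The client reads `wav_scp` as Kaldi-style lines of the form `<utt_id> <wav_path>` and takes `line.split()[1]`. That expression raises `IndexError` on a blank line and truncates any path that contains spaces. The following is a minimal parsing sketch under the same two-column assumption. `parse_wav_scp` is a hypothetical helper and is not part of `grpc_client`.

```python
from typing import Iterable, List, Tuple


def parse_wav_scp(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse Kaldi-style "<utt_id> <wav_path>" lines into (utt_id, path)."""
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. a trailing empty line
        # Split once so that a path containing spaces stays intact;
        # a line with no path raises ValueError here.
        utt_id, wav_path = line.split(maxsplit=1)
        entries.append((utt_id, wav_path))
    return entries


lines = ["utt1 /data/a.wav\n", "\n", "utt2 /data/my clips/b.wav\n"]
print(parse_wav_scp(lines))
# [('utt1', '/data/a.wav'), ('utt2', '/data/my clips/b.wav')]
```

Inside `grpc_rec`, iterating over `parse_wav_scp(wav_scp)` and using the second element of each pair would replace `line.split()[1]`.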
| | |
| | | else: |
| | | asr_result = "" |
| | | elif self.backend == "onnxruntime": |
| | | from rapid_paraformer.utils.frontend import load_bytes |
| | | from funasr_onnx.utils.frontend import load_bytes |
| | | array = load_bytes(tmp_data) |
| | | asr_result = self.inference_16k_pipeline(array)[0] |
| | | end_time = int(round(time.time() * 1000)) |
| | |
| | | |
| | | [FunASR](https://github.com/alibaba-damo-academy/FunASR) aims to build a bridge between academic research and industrial applications in speech recognition. By supporting the training and fine-tuning of the industrial-grade speech recognition models released on ModelScope, it lets researchers and developers research and deploy speech recognition models more conveniently, and helps grow the speech recognition ecosystem. ASR for Fun! |
| | | |
| | | ### Introduction |
| | | - Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary). |
| | | |
| | | ### Steps: |
| | | 1. Export the model. |
| | |
| | | |
| | | Install from pip |
| | | ```shell |
| | | pip install --upgrade funasr_torch -i https://pypi.Python.org/simple |
| | | pip install -U funasr_torch |
| | | # For the users in China, you could install with the command: |
| | | # pip install -U funasr_torch -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | |
| | | ``` |
| | | Or install from source code |
| | | |
| | | ```shell |
| | | git clone https://github.com/alibaba/FunASR.git && cd FunASR |
| | | cd funasr/runtime/python/libtorch |
| | | python setup.py build |
| | | python setup.py install |
| | | pip install -e ./ |
| | | # For the users in China, you could install with the command: |
| | | # pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | |
| | | ``` |
| | | |
| | | 3. Run the demo. |
| | |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | plot_timestamp_to: str = "", |
| | | pred_bias: int = 1, |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 1, |
| | | ): |
| | |
| | | self.batch_size = batch_size |
| | | self.device_id = device_id |
| | | self.plot_timestamp_to = plot_timestamp_to |
| | | self.pred_bias = pred_bias |
| | | if "predictor_bias" in config['model_conf'].keys(): |
| | | self.pred_bias = config['model_conf']['predictor_bias'] |
| | | else: |
| | | self.pred_bias = 0 |
| | | |
| | | def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: |
| | | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) |
| | |
| | | ): |
| | | check_argument_types() |
| | | |
| | | # self.token_list = self.load_token(token_path) |
| | | self.token_list = token_list |
| | | self.unk_symbol = token_list[-1] |
| | | self.token2id = {v: i for i, v in enumerate(self.token_list)} |
| | | self.unk_id = self.token2id[self.unk_symbol] |
| | | |
| | | |
| | | def get_num_vocabulary_size(self) -> int: |
| | | return len(self.token_list) |
| | |
| | | return [self.token_list[i] for i in integers] |
| | | |
| | | def tokens2ids(self, tokens: Iterable[str]) -> List[int]: |
| | | token2id = {v: i for i, v in enumerate(self.token_list)} |
| | | if self.unk_symbol not in token2id: |
| | | raise TokenIDConverterError( |
| | | f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list" |
| | | ) |
| | | unk_id = token2id[self.unk_symbol] |
| | | return [token2id.get(i, unk_id) for i in tokens] |
| | | |
| | | return [self.token2id.get(i, self.unk_id) for i in tokens] |
| | | |
| | | |
| | | class CharTokenizer(): |
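The `TokenIDConverter` change above builds `token2id` and `unk_id` once in `__init__` instead of on every `tokens2ids` call. The constructor shown here takes `unk_symbol` from `token_list[-1]`, so that symbol is always in the map, and the `TokenIDConverterError` check in the older `tokens2ids` body could never fire. Below is a minimal standalone sketch of the same pattern. The class name is hypothetical and this is not the funasr class itself.

```python
from typing import Iterable, List


class TokenIDConverterSketch:
    """Hypothetical stand-in for TokenIDConverter (not the funasr class)."""

    def __init__(self, token_list: List[str]):
        self.token_list = list(token_list)
        # As in the code above, the last vocabulary entry is the unknown symbol.
        self.unk_symbol = self.token_list[-1]
        # Built once here instead of on every tokens2ids() call.
        self.token2id = {tok: i for i, tok in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def ids2tokens(self, ids: Iterable[int]) -> List[str]:
        return [self.token_list[i] for i in ids]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        # Out-of-vocabulary tokens fall back to unk_id.
        return [self.token2id.get(tok, self.unk_id) for tok in tokens]


conv = TokenIDConverterSketch(["你", "好", "<unk>"])
print(conv.tokens2ids(["好", "吗", "你"]))  # [1, 2, 0]
```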
| | |
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='torch_paraformer'): |
| | | def get_logger(name='funasr_torch'): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| | |
| | | |
| | | setuptools.setup( |
| | | name='funasr_torch', |
| | | version='0.0.3', |
| | | version='0.0.4', |
| | | platforms="Any", |
| | | url="https://github.com/alibaba-damo-academy/FunASR.git", |
| | | author="Speech Lab, Alibaba Group, China", |
| | | author="Speech Lab of DAMO Academy, Alibaba Group", |
| | | author_email="funasr@list.alibaba-inc.com", |
| | | description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit", |
| | | license="The MIT License", |
| | |
| | | "PyYAML>=5.1.2", "torch-quant >= 0.4.0"], |
| | | packages=find_packages(include=["funasr_torch*"]), |
| | | keywords=[ |
| | | 'funasr,paraformer, funasr_torch' |
| | | 'funasr, paraformer, funasr_torch' |
| | | ], |
| | | classifiers=[ |
| | | 'Programming Language :: Python :: 3.6', |
| | |
| | | ## Using funasr with ONNXRuntime |
| | | |
| | | |
| | | ### Introduction |
| | | - Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary). |
| | | |
| | | |
| | | ### Steps: |
| | | 1. Export the model. |
| | | - Command: (`Tips`: torch >= 1.11.0 is required.) |
| | |
| | | |
| | | Install from pip |
| | | ```shell |
| | | pip install --upgrade funasr_onnx -i https://pypi.Python.org/simple |
| | | pip install -U funasr_onnx |
| | | # For the users in China, you could install with the command: |
| | | # pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | |
| | | ``` |
| | | |
| | | Or install from source code |
| | |
| | | ```shell |
| | | git clone https://github.com/alibaba/FunASR.git && cd FunASR |
| | | cd funasr/runtime/python/onnxruntime |
| | | python setup.py build |
| | | python setup.py install |
| | | pip install -e ./ |
| | | # For the users in China, you could install with the command: |
| | | # pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple |
| | | |
| | | ``` |
| | | |
| | | 3. Run the demo. |
| | |
| | | from funasr_onnx import Paraformer |
| | | |
| | | |
| | | model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | |
| | | model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0) # cpu |
| | |
| | | |
| | | # when using the paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to also save a figure of the alignment alongside the timestamps |
| | | # model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png") |
| | | |
| | | |
| | | wav_path = "YourPath/xx.wav" |
| | | |
| New file |
| | |
| | | from funasr_onnx import CT_Transformer |
| | | |
| | | model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" |
| | | model = CT_Transformer(model_dir) |
| | | |
| | | text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益" |
| | | result = model(text_in) |
| | | print(result[0]) |
| New file |
| | |
| | | from funasr_onnx import CT_Transformer_VadRealtime |
| | | |
| | | model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727" |
| | | model = CT_Transformer_VadRealtime(model_dir) |
| | | |
| | | text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益" |
| | | |
| | | vads = text_in.split("|") |
| | | rec_result_all="" |
| | | param_dict = {"cache": []} |
| | | for vad in vads: |
| | | result = model(vad, param_dict=param_dict) |
| | | rec_result_all += result[0] |
| | | |
| | | print(rec_result_all) |
| New file |
| | |
| | | import soundfile |
| | | from funasr_onnx import Fsmn_vad |
| | | |
| | | |
| | | model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" |
| | | wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav" |
| | | model = Fsmn_vad(model_dir) |
| | | |
| | | # offline vad |
| | | result = model(wav_path) |
| | | print(result) |
| New file |
| | |
| | | import soundfile |
| | | from funasr_onnx import Fsmn_vad_online |
| | | |
| | | |
| | | model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" |
| | | wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav" |
| | | model = Fsmn_vad_online(model_dir) |
| | | |
| | | |
| | | # online vad |
| | | speech, sample_rate = soundfile.read(wav_path) |
| | | speech_length = speech.shape[0] |
| | | # |
| | | sample_offset = 0 |
| | | step = 1600 |
| | | param_dict = {'in_cache': []} |
| | | for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)): |
| | | if sample_offset + step >= speech_length - 1: |
| | | step = speech_length - sample_offset |
| | | is_final = True |
| | | else: |
| | | is_final = False |
| | | param_dict['is_final'] = is_final |
| | | segments_result = model(audio_in=speech[sample_offset: sample_offset + step], |
| | | param_dict=param_dict) |
| | | if segments_result: |
| | | print(segments_result) |
| | | |
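The demo feeds 1600-sample chunks (100 ms at 16 kHz) and sets `param_dict['is_final']` on the last one. Its loop recomputes `step` inside a `range()` whose stride was fixed up front. As a result, when `speech_length` is one sample more than a multiple of 1600 (and longer than 1600), the final sample is sent again as a second `is_final` chunk. Below is a minimal generator sketch that yields each span exactly once. `iter_chunks` is a hypothetical helper.

```python
from typing import Iterator, Tuple


def iter_chunks(num_samples: int, step: int = 1600) -> Iterator[Tuple[int, int, bool]]:
    """Yield (start, end, is_final); speech[start:end] covers each sample once."""
    for start in range(0, num_samples, step):
        end = min(start + step, num_samples)
        yield start, end, end == num_samples


print(list(iter_chunks(3201)))
# [(0, 1600, False), (1600, 3200, False), (3200, 3201, True)]
```

In the demo loop, `speech[start:end]` would be passed as `audio_in` and `is_final` stored in `param_dict['is_final']`. Unlike the original loop, the sketch does not fold a one-sample tail into the previous chunk.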
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from .paraformer_bin import Paraformer |
| | | from .vad_bin import Fsmn_vad |
| | | from .vad_bin import Fsmn_vad_online |
| | | from .punc_bin import CT_Transformer |
| | | from .punc_bin import CT_Transformer_VadRealtime |
| | |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | plot_timestamp_to: str = "", |
| | | pred_bias: int = 1, |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4, |
| | | ): |
| | |
| | | self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) |
| | | self.batch_size = batch_size |
| | | self.plot_timestamp_to = plot_timestamp_to |
| | | self.pred_bias = pred_bias |
| | | if "predictor_bias" in config['model_conf'].keys(): |
| | | self.pred_bias = config['model_conf']['predictor_bias'] |
| | | else: |
| | | self.pred_bias = 0 |
| | | |
| | | def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: |
| | | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) |
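The new `punc_bin.py` below punctuates long text in fixed-size mini-sentences (`split_size=20` words). Any unfinished sentence is carried over to the next mini-sentence as a cache. To decide where to cut, the code scans backwards for the last period or question mark. If none is found and the mini-sentence plus cache exceeds `cache_pop_trigger_limit` (200 words), it cuts at the last comma instead. Below is a simplified Python sketch of that control flow, operating on punctuation strings rather than model ids. The helper names and the `_` (no punctuation) label are hypothetical, and this is not the library's `split_to_mini_sentence`.

```python
from typing import List, Sequence, Tuple


def split_to_mini_sentences(words: Sequence, size: int = 20) -> List[list]:
    """Fixed-size chunks of `size` words; the last chunk holds the remainder."""
    return [list(words[i:i + size]) for i in range(0, len(words), size)]


def find_cut(labels: Sequence[str], limit: int = 200) -> Tuple[int, bool]:
    """Pick where to cut a punctuated mini-sentence (labels[i] follows word i).

    Returns (cut_index, cut_at_comma). Words [0, cut_index] are emitted and the
    rest is cached for the next mini-sentence; cut_index == -1 caches it all.
    """
    last_comma = -1
    # Scan backwards from the second-to-last word down to index 2.
    for i in range(len(labels) - 2, 1, -1):
        if labels[i] in ("。", "？"):
            return i, False
        if last_comma < 0 and labels[i] == "，":
            last_comma = i
    # No sentence end found: force a cut at the last comma if the text is long.
    if len(labels) > limit and last_comma >= 0:
        return last_comma, True  # the caller relabels this comma as "。"
    return -1, False


labels = ["_", "_", "_", "。", "_", "，", "_", "_"]
cut, at_comma = find_cut(labels)
print(cut, at_comma)  # prints "3 False": emit words 0..3, cache words 4..7
```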
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import os.path |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | | import numpy as np |
| | | |
| | | from .utils.utils import (ONNXRuntimeError, |
| | | OrtInferSession, get_logger, |
| | | read_yaml) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence, code_mix_split_words) |
| | | logging = get_logger() |
| | | |
| | | |
| | | class CT_Transformer(): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4 |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | | raise FileNotFoundError(f'{model_dir} does not exist.') |
| | | |
| | | model_file = os.path.join(model_dir, 'model.onnx') |
| | | if quantize: |
| | | model_file = os.path.join(model_dir, 'model_quant.onnx') |
| | | config_file = os.path.join(model_dir, 'punc.yaml') |
| | | config = read_yaml(config_file) |
| | | |
| | | self.converter = TokenIDConverter(config['token_list']) |
| | | self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) |
| | | self.batch_size = 1 |
| | | self.punc_list = config['punc_list'] |
| | | self.period = 0 |
| | | for i in range(len(self.punc_list)): |
| | | if self.punc_list[i] == ",": |
| | | self.punc_list[i] = "," |
| | | elif self.punc_list[i] == "?": |
| | | self.punc_list[i] = "?" |
| | | elif self.punc_list[i] == "。": |
| | | self.period = i |
| | | |
| | | def __call__(self, text: Union[list, str], split_size=20): |
| | | split_text = code_mix_split_words(text) |
| | | split_text_id = self.converter.tokens2ids(split_text) |
| | | mini_sentences = split_to_mini_sentence(split_text, split_size) |
| | | mini_sentences_id = split_to_mini_sentence(split_text_id, split_size) |
| | | assert len(mini_sentences) == len(mini_sentences_id) |
| | | cache_sent = [] |
| | | cache_sent_id = [] |
| | | new_mini_sentence = "" |
| | | new_mini_sentence_punc = [] |
| | | cache_pop_trigger_limit = 200 |
| | | for mini_sentence_i in range(len(mini_sentences)): |
| | | mini_sentence = mini_sentences[mini_sentence_i] |
| | | mini_sentence_id = mini_sentences_id[mini_sentence_i] |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64') |
| | | data = { |
| | | "text": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'), |
| | | } |
| | | try: |
| | | outputs = self.infer(data['text'], data['text_lengths']) |
| | | y = outputs[0] |
| | | punctuations = np.argmax(y,axis=-1)[0] |
| | | assert punctuations.size == len(mini_sentence) |
| | | except ONNXRuntimeError: |
| | | logging.exception("ONNX Runtime inference failed") |
| | | raise |
| | |
| | | # Find the last period/question mark; the words after it are cached for the next mini-sentence |
| | | if mini_sentence_i < len(mini_sentences) - 1: |
| | | sentenceEnd = -1 |
| | | last_comma_index = -1 |
| | | for i in range(len(punctuations) - 2, 1, -1): |
| | | if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?": |
| | | sentenceEnd = i |
| | | break |
| | | if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",": |
| | | last_comma_index = i |
| | | |
| | | if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0: |
| | | # The sentence is too long, cut it off at a comma. |
| | | sentenceEnd = last_comma_index |
| | | punctuations[sentenceEnd] = self.period |
| | | cache_sent = mini_sentence[sentenceEnd + 1:] |
| | | cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist() |
| | | mini_sentence = mini_sentence[0:sentenceEnd + 1] |
| | | punctuations = punctuations[0:sentenceEnd + 1] |
| | | |
| | | new_mini_sentence_punc += [int(x) for x in punctuations] |
| | | words_with_punc = [] |
| | | for i in range(len(mini_sentence)): |
| | | if i > 0: |
| | | if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1: |
| | | mini_sentence[i] = " " + mini_sentence[i] |
| | | words_with_punc.append(mini_sentence[i]) |
| | | if self.punc_list[punctuations[i]] != "_": |
| | | words_with_punc.append(self.punc_list[punctuations[i]]) |
| | | new_mini_sentence += "".join(words_with_punc) |
| | | # Append a period at the end of the last sentence |
| | | new_mini_sentence_out = new_mini_sentence |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc |
| | | if mini_sentence_i == len(mini_sentences) - 1: |
| | | if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、": |
| | | new_mini_sentence_out = new_mini_sentence[:-1] + "。" |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] |
| | | elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?": |
| | | new_mini_sentence_out = new_mini_sentence + "。" |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] |
| | | return new_mini_sentence_out, new_mini_sentence_punc_out |
| | | |
| | | def infer(self, feats: np.ndarray, |
| | | feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
| | | outputs = self.ort_infer([feats, feats_len]) |
| | | return outputs |
| | | |
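The chunk-and-cache loop in `CT_Transformer.__call__` can be tried in isolation with the minimal sketch below. It is a hypothetical re-implementation for illustration, not the module's code: `chunk_words` assumes that `split_to_mini_sentence` cuts the word list into consecutive pieces of at most `split_size` words (its body is not part of this diff), while `find_cache_split` and `join_with_punc` mirror the backward search for the last `。`/`?` (with the comma fallback above 200 words) and the rule that puts a space between consecutive single-byte (ASCII) words.

```python
from typing import List, Sequence

SENTENCE_END = ("。", "?")


def chunk_words(words: Sequence[str], size: int) -> List[List[str]]:
    """Split `words` into consecutive chunks of at most `size` words (assumed behavior)."""
    assert size > 1
    return [list(words[i:i + size]) for i in range(0, len(words), size)]


def find_cache_split(puncs: Sequence[str], limit: int = 200) -> int:
    """Index of the last word to emit; the words after it are carried into the next chunk.

    Scans backwards from the second-to-last position down to index 2, like the loop
    in CT_Transformer.__call__, and falls back to the last comma only when the chunk
    is longer than `limit` words. -1 means "emit nothing, cache the whole chunk".
    """
    end, last_comma = -1, -1
    for i in range(len(puncs) - 2, 1, -1):
        if puncs[i] in SENTENCE_END:
            end = i
            break
        if last_comma < 0 and puncs[i] == ",":
            last_comma = i
    if end < 0 and len(puncs) > limit and last_comma >= 0:
        end = last_comma
    return end


def join_with_punc(words: Sequence[str], puncs: Sequence[str]) -> str:
    """Join words with their predicted marks; "_" means no punctuation after the word."""
    out = []
    for i, (word, punc) in enumerate(zip(words, puncs)):
        # Two consecutive words starting with a single-byte UTF-8 character
        # (e.g. English inside Chinese text) are separated by a space.
        if i > 0 and len(word[0].encode()) == 1 and len(words[i - 1][0].encode()) == 1:
            word = " " + word
        out.append(word)
        if punc != "_":
            out.append(punc)
    return "".join(out)


words = ["你好", "hello", "world", "再见", "好"]
puncs = ["_", "_", ",", "。", "_"]
end = find_cache_split(puncs)  # 3: "再见" carries the sentence-final mark
print(join_with_punc(words[:end + 1], puncs[:end + 1]))  # 你好hello world,再见。
print(words[end + 1:])  # ['好'] is carried into the next chunk
```

Carrying the unfinished tail forward lets the model see the start of a sentence together with its continuation in the next mini-sentence, instead of forcing a mark at an arbitrary chunk boundary.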
| | | |
| | | class CT_Transformer_VadRealtime(CT_Transformer): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4 |
| | | ): |
| | | super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads) |
| | | |
| | | def __call__(self, text: str, param_dict: dict, split_size=20): |
| | | cache_key = "cache" |
| | | assert cache_key in param_dict |
| | | cache = param_dict[cache_key] |
| | | if cache is not None and len(cache) > 0: |
| | | precache = "".join(cache) |
| | | else: |
| | | precache = "" |
| | | cache = [] |
| | | full_text = precache + text |
| | | split_text = code_mix_split_words(full_text) |
| | | split_text_id = self.converter.tokens2ids(split_text) |
| | | mini_sentences = split_to_mini_sentence(split_text, split_size) |
| | | mini_sentences_id = split_to_mini_sentence(split_text_id, split_size) |
| | | new_mini_sentence_punc = [] |
| | | assert len(mini_sentences) == len(mini_sentences_id) |
| | | |
| | | cache_sent = [] |
| | | cache_sent_id = np.array([], dtype='int32') |
| | | sentence_punc_list = [] |
| | | sentence_words_list = [] |
| | | cache_pop_trigger_limit = 200 |
| | | skip_num = 0 |
| | | for mini_sentence_i in range(len(mini_sentences)): |
| | | mini_sentence = mini_sentences[mini_sentence_i] |
| | | mini_sentence_id = mini_sentences_id[mini_sentence_i] |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0) |
| | | text_length = len(mini_sentence_id) |
| | | data = { |
| | | "input": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([text_length], dtype='int32'), |
| | | "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), |
| | | "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | } |
| | | try: |
| | | outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"]) |
| | | y = outputs[0] |
| | | punctuations = np.argmax(y,axis=-1)[0] |
| | | assert punctuations.size == len(mini_sentence) |
| | | except ONNXRuntimeError: |
| | | logging.exception("ONNX Runtime inference failed") |
| | | raise |
| | |
| | | # Find the last period/question mark; the words after it are cached for the next mini-sentence |
| | | if mini_sentence_i < len(mini_sentences) - 1: |
| | | sentenceEnd = -1 |
| | | last_comma_index = -1 |
| | | for i in range(len(punctuations) - 2, 1, -1): |
| | | if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?": |
| | | sentenceEnd = i |
| | | break |
| | | if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",": |
| | | last_comma_index = i |
| | | |
| | | if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0: |
| | | # The sentence is too long, cut it off at a comma. |
| | | sentenceEnd = last_comma_index |
| | | punctuations[sentenceEnd] = self.period |
| | | cache_sent = mini_sentence[sentenceEnd + 1:] |
| | | cache_sent_id = mini_sentence_id[sentenceEnd + 1:] |
| | | mini_sentence = mini_sentence[0:sentenceEnd + 1] |
| | | punctuations = punctuations[0:sentenceEnd + 1] |
| | | |
| | | punctuations_np = [int(x) for x in punctuations] |
| | | new_mini_sentence_punc += punctuations_np |
| | | sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np] |
| | | sentence_words_list += mini_sentence |
| | | |
| | | assert len(sentence_punc_list) == len(sentence_words_list) |
| | | words_with_punc = [] |
| | | sentence_punc_list_out = [] |
| | | for i in range(0, len(sentence_words_list)): |
| | | if i > 0: |
| | | if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1: |
| | | sentence_words_list[i] = " " + sentence_words_list[i] |
| | | if skip_num < len(cache): |
| | | skip_num += 1 |
| | | else: |
| | | words_with_punc.append(sentence_words_list[i]) |
| | | if skip_num >= len(cache): |
| | | sentence_punc_list_out.append(sentence_punc_list[i]) |
| | | if sentence_punc_list[i] != "_": |
| | | words_with_punc.append(sentence_punc_list[i]) |
| | | sentence_out = "".join(words_with_punc) |
| | | |
| | | sentenceEnd = -1 |
| | | for i in range(len(sentence_punc_list) - 2, 1, -1): |
| | | if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?": |
| | | sentenceEnd = i |
| | | break |
| | | cache_out = sentence_words_list[sentenceEnd + 1:] |
| | | if sentence_out and sentence_out[-1] in self.punc_list: |
| | | sentence_out = sentence_out[:-1] |
| | | sentence_punc_list_out[-1] = "_" |
| | | param_dict[cache_key] = cache_out |
| | | return sentence_out, sentence_punc_list_out, cache_out |
| | | |
| | | def vad_mask(self, size, vad_pos, dtype=bool): |
| | | """Create the VAD mask for self-attention. |
| | |
| | | :param int size: size of the square mask (text length) |
| | | :param int vad_pos: position of the VAD boundary (number of cached words) |
| | | :param dtype: NumPy dtype of the result |
| | | :rtype: numpy.ndarray of shape (size, size) |
| | | """ |
| | | ret = np.ones((size, size), dtype=dtype) |
| | | if vad_pos <= 0 or vad_pos >= size: |
| | | return ret |
| | | sub_corner = np.zeros( |
| | | (vad_pos - 1, size - vad_pos), dtype=dtype) |
| | | ret[0:vad_pos - 1, vad_pos:] = sub_corner |
| | | return ret |
| | | |
| | | def infer(self, feats: np.ndarray, |
| | | feats_len: np.ndarray, |
| | | vad_mask: np.ndarray, |
| | | sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
| | | outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks]) |
| | | return outputs |
| | | |
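`CT_Transformer_VadRealtime.__call__` passes two extra inputs to the ONNX model: the VAD mask from `vad_mask` and a lower-triangular `sub_masks`. The sketch below rebuilds both with plain NumPy, using the builtin `bool` as the default dtype because the deprecated `np.bool` alias was removed in NumPy 1.24. How the exported model combines the two masks is not visible in this diff, so the sketch only checks their shapes and values; `model_masks` is a hypothetical helper name.

```python
import numpy as np


def vad_mask(size: int, vad_pos: int, dtype=bool) -> np.ndarray:
    """Rebuild CT_Transformer_VadRealtime.vad_mask with the builtin bool dtype."""
    mask = np.ones((size, size), dtype=dtype)
    if vad_pos <= 0 or vad_pos >= size:
        return mask  # no cache, or the boundary falls outside the sequence
    # Rows 0 .. vad_pos - 2 are blocked from columns vad_pos .. size - 1.
    mask[0:vad_pos - 1, vad_pos:] = False
    return mask


def model_masks(text_length: int, cache_len: int):
    """Hypothetical helper: the two extra model inputs, each shaped (1, 1, L, L)."""
    vad = vad_mask(text_length, cache_len)[None, None, :, :].astype(np.float32)
    sub = np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :]
    return vad, sub


# Rows 0 and 1 have zeros in columns 3 and 4; every other entry is 1.
print(vad_mask(5, 3).astype(int))
```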
| New file |
| | |
| | | from enum import Enum |
| | | from typing import List, Tuple, Dict, Any |
| | | |
| | | import math |
| | | import numpy as np |
| | | |
| | | class VadStateMachine(Enum): |
| | | kVadInStateStartPointNotDetected = 1 |
| | | kVadInStateInSpeechSegment = 2 |
| | | kVadInStateEndPointDetected = 3 |
| | | |
| | | |
| | | class FrameState(Enum): |
| | | kFrameStateInvalid = -1 |
| | | kFrameStateSpeech = 1 |
| | | kFrameStateSil = 0 |
| | | |
| | | |
| | | # final voice/unvoice state per frame |
| | | class AudioChangeState(Enum): |
| | | kChangeStateSpeech2Speech = 0 |
| | | kChangeStateSpeech2Sil = 1 |
| | | kChangeStateSil2Sil = 2 |
| | | kChangeStateSil2Speech = 3 |
| | | kChangeStateNoBegin = 4 |
| | | kChangeStateInvalid = 5 |
| | | |
| | | |
| | | class VadDetectMode(Enum): |
| | | kVadSingleUtteranceDetectMode = 0 |
| | | kVadMutipleUtteranceDetectMode = 1 |
| | | |
| | | |
| | | class VADXOptions: |
| | | def __init__( |
| | | self, |
| | | sample_rate: int = 16000, |
| | | detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value, |
| | | snr_mode: int = 0, |
| | | max_end_silence_time: int = 800, |
| | | max_start_silence_time: int = 3000, |
| | | do_start_point_detection: bool = True, |
| | | do_end_point_detection: bool = True, |
| | | window_size_ms: int = 200, |
| | | sil_to_speech_time_thres: int = 150, |
| | | speech_to_sil_time_thres: int = 150, |
| | | speech_2_noise_ratio: float = 1.0, |
| | | do_extend: int = 1, |
| | | lookback_time_start_point: int = 200, |
| | | lookahead_time_end_point: int = 100, |
| | | max_single_segment_time: int = 60000, |
| | | nn_eval_block_size: int = 8, |
| | | dcd_block_size: int = 4, |
| | | snr_thres: float = -100.0, |
| | | noise_frame_num_used_for_snr: int = 100, |
| | | decibel_thres: float = -100.0, |
| | | speech_noise_thres: float = 0.6, |
| | | fe_prior_thres: float = 1e-4, |
| | | silence_pdf_num: int = 1, |
| | | sil_pdf_ids: List[int] = [0], |
| | | speech_noise_thresh_low: float = -0.1, |
| | | speech_noise_thresh_high: float = 0.3, |
| | | output_frame_probs: bool = False, |
| | | frame_in_ms: int = 10, |
| | | frame_length_ms: int = 25, |
| | | ): |
| | | self.sample_rate = sample_rate |
| | | self.detect_mode = detect_mode |
| | | self.snr_mode = snr_mode |
| | | self.max_end_silence_time = max_end_silence_time |
| | | self.max_start_silence_time = max_start_silence_time |
| | | self.do_start_point_detection = do_start_point_detection |
| | | self.do_end_point_detection = do_end_point_detection |
| | | self.window_size_ms = window_size_ms |
| | | self.sil_to_speech_time_thres = sil_to_speech_time_thres |
| | | self.speech_to_sil_time_thres = speech_to_sil_time_thres |
| | | self.speech_2_noise_ratio = speech_2_noise_ratio |
| | | self.do_extend = do_extend |
| | | self.lookback_time_start_point = lookback_time_start_point |
| | | self.lookahead_time_end_point = lookahead_time_end_point |
| | | self.max_single_segment_time = max_single_segment_time |
| | | self.nn_eval_block_size = nn_eval_block_size |
| | | self.dcd_block_size = dcd_block_size |
| | | self.snr_thres = snr_thres |
| | | self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr |
| | | self.decibel_thres = decibel_thres |
| | | self.speech_noise_thres = speech_noise_thres |
| | | self.fe_prior_thres = fe_prior_thres |
| | | self.silence_pdf_num = silence_pdf_num |
| | | self.sil_pdf_ids = sil_pdf_ids |
| | | self.speech_noise_thresh_low = speech_noise_thresh_low |
| | | self.speech_noise_thresh_high = speech_noise_thresh_high |
| | | self.output_frame_probs = output_frame_probs |
| | | self.frame_in_ms = frame_in_ms |
| | | self.frame_length_ms = frame_length_ms |
| | | |
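With the default `VADXOptions`, the millisecond settings translate into frame and sample counts as shown below. This sketch assumes 16 kHz mono audio and repeats the arithmetic of `WindowDetector`, `LatencyFrmNumAtStartPoint`/`GetLatency`, and `ComputeDecibel` (per-frame energy in dB with a `1e-6` floor); the function names are hypothetical.

```python
import math

import numpy as np


def ms_to_frames(ms: int, frame_in_ms: int = 10) -> int:
    """Convert a duration in milliseconds to a number of frame shifts."""
    return int(ms / frame_in_ms)


def start_latency_ms(window_size_ms: int = 200, lookback_time_start_point: int = 200,
                     frame_in_ms: int = 10, do_extend: bool = True) -> int:
    """Start-point latency: detection window plus look-back, in milliseconds."""
    frames = ms_to_frames(window_size_ms, frame_in_ms)
    if do_extend:
        frames += ms_to_frames(lookback_time_start_point, frame_in_ms)
    return frames * frame_in_ms


def frame_decibels(wave: np.ndarray, sample_rate: int = 16000,
                   frame_length_ms: int = 25, frame_in_ms: int = 10) -> np.ndarray:
    """Energy in dB of each 25 ms frame, hopping by 10 ms."""
    frame_len = int(frame_length_ms * sample_rate / 1000)  # 400 samples at 16 kHz
    shift = int(frame_in_ms * sample_rate / 1000)          # 160 samples at 16 kHz
    starts = range(0, len(wave) - frame_len + 1, shift)
    return np.array([10 * math.log10(np.square(wave[s:s + frame_len]).sum() + 1e-6)
                     for s in starts])


print(ms_to_frames(200), ms_to_frames(150))  # 20 15
print(start_latency_ms())                    # 400
db = frame_decibels(np.zeros(16000, dtype=np.float32))
print(len(db))                               # 98 frames in one second; silence is about -60 dB
```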
| | | |
| | | class E2EVadSpeechBufWithDoa(object): |
| | | def __init__(self): |
| | | self.start_ms = 0 |
| | | self.end_ms = 0 |
| | | self.buffer = [] |
| | | self.contain_seg_start_point = False |
| | | self.contain_seg_end_point = False |
| | | self.doa = 0 |
| | | |
| | | def Reset(self): |
| | | self.start_ms = 0 |
| | | self.end_ms = 0 |
| | | self.buffer = [] |
| | | self.contain_seg_start_point = False |
| | | self.contain_seg_end_point = False |
| | | self.doa = 0 |
| | | |
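In online mode (`online=True`), `E2EVadModel.__call__` reports each `E2EVadSpeechBufWithDoa` as a `[start_ms, end_ms]` pair in which `-1` marks a boundary that is either still open or was already reported in an earlier call. The sketch below reproduces that bookkeeping on a hypothetical `Seg` dataclass standing in for the buffer class, so the `[start, -1]` then `[-1, end]` sequence can be checked without a model.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Seg:
    """Hypothetical stand-in for E2EVadSpeechBufWithDoa."""
    start_ms: int
    end_ms: int
    contain_seg_start_point: bool
    contain_seg_end_point: bool


def emit_online(segs: List[Seg], offset: int,
                next_seg: bool) -> Tuple[List[List[int]], int, bool]:
    """Return ([start_ms, end_ms] pairs, new offset, new next_seg) for one call."""
    out = []
    for seg in segs[offset:]:
        if not seg.contain_seg_start_point:
            continue
        if not next_seg and not seg.contain_seg_end_point:
            continue  # still-open segment whose start was already reported
        start_ms = seg.start_ms if next_seg else -1
        if seg.contain_seg_end_point:
            end_ms = seg.end_ms
            next_seg = True  # the next segment will report its own start
            offset += 1      # this segment is finished; skip it on later calls
        else:
            end_ms = -1
            next_seg = False
        out.append([start_ms, end_ms])
    return out, offset, next_seg


segs = [Seg(100, 500, True, False)]
out, offset, next_seg = emit_online(segs, 0, True)           # [[100, -1]]: start reported
out, offset, next_seg = emit_online(segs, offset, next_seg)  # []: still open, nothing new
segs[0].end_ms, segs[0].contain_seg_end_point = 900, True
out, offset, next_seg = emit_online(segs, offset, next_seg)  # [[-1, 900]]: end reported
```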
| | | |
| | | class E2EVadFrameProb(object): |
| | | def __init__(self): |
| | | self.noise_prob = 0.0 |
| | | self.speech_prob = 0.0 |
| | | self.score = 0.0 |
| | | self.frame_id = 0 |
| | | self.frm_state = 0 |
| | | |
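`E2EVadFrameProb` stores the two log scores that `GetFrameState` compares. The following is a hedged sketch of that per-frame rule: the silence probability is the sum of the model scores at `sil_pdf_ids`, and a frame counts as speech only if `exp(speech_prob) >= exp(noise_prob) + speech_noise_thres` and it also passes the SNR and decibel thresholds. `is_speech_frame` is a hypothetical name, and its defaults are copied from `VADXOptions`. With `speech_2_noise_ratio = 1.0` and `speech_noise_thres = 0.6`, the rule reduces to a silence probability of at most 0.2.

```python
import math
from typing import Sequence


def is_speech_frame(frame_scores: Sequence[float], decibel: float,
                    noise_average_decibel: float, sil_pdf_ids: Sequence[int] = (0,),
                    speech_2_noise_ratio: float = 1.0, speech_noise_thres: float = 0.6,
                    snr_thres: float = -100.0, decibel_thres: float = -100.0) -> bool:
    """Speech/silence decision for one frame, following GetFrameState."""
    if decibel < decibel_thres:  # too quiet: silence without consulting the model
        return False
    # Assumes 0 < sil_score < 1; math.log raises ValueError at 0.
    sil_score = sum(frame_scores[i] for i in sil_pdf_ids)
    noise_prob = math.log(sil_score) * speech_2_noise_ratio
    speech_prob = math.log(1.0 - sil_score)
    if math.exp(speech_prob) < math.exp(noise_prob) + speech_noise_thres:
        return False
    snr = decibel - noise_average_decibel
    return snr >= snr_thres and decibel >= decibel_thres


print(is_speech_frame([0.1, 0.9], decibel=-30.0, noise_average_decibel=-60.0))  # True
print(is_speech_frame([0.3, 0.7], decibel=-30.0, noise_average_decibel=-60.0))  # False
```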
| | | |
| | | class WindowDetector(object): |
| | | def __init__(self, window_size_ms: int, sil_to_speech_time: int, |
| | | speech_to_sil_time: int, frame_size_ms: int): |
| | | self.window_size_ms = window_size_ms |
| | | self.sil_to_speech_time = sil_to_speech_time |
| | | self.speech_to_sil_time = speech_to_sil_time |
| | | self.frame_size_ms = frame_size_ms |
| | | |
| | | self.win_size_frame = int(window_size_ms / frame_size_ms) |
| | | self.win_sum = 0 |
| | | self.win_state = [0] * self.win_size_frame # initialize the window |
| | | |
| | | self.cur_win_pos = 0 |
| | | self.pre_frame_state = FrameState.kFrameStateSil |
| | | self.cur_frame_state = FrameState.kFrameStateSil |
| | | self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) |
| | | self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) |
| | | |
| | | self.voice_last_frame_count = 0 |
| | | self.noise_last_frame_count = 0 |
| | | self.hydre_frame_count = 0 |
| | | |
| | | def Reset(self) -> None: |
| | | self.cur_win_pos = 0 |
| | | self.win_sum = 0 |
| | | self.win_state = [0] * self.win_size_frame |
| | | self.pre_frame_state = FrameState.kFrameStateSil |
| | | self.cur_frame_state = FrameState.kFrameStateSil |
| | | self.voice_last_frame_count = 0 |
| | | self.noise_last_frame_count = 0 |
| | | self.hydre_frame_count = 0 |
| | | |
| | | def GetWinSize(self) -> int: |
| | | return int(self.win_size_frame) |
| | | |
| | | def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState: |
| | | cur_frame_state = FrameState.kFrameStateSil |
| | | if frameState == FrameState.kFrameStateSpeech: |
| | | cur_frame_state = 1 |
| | | elif frameState == FrameState.kFrameStateSil: |
| | | cur_frame_state = 0 |
| | | else: |
| | | return AudioChangeState.kChangeStateInvalid |
| | | self.win_sum -= self.win_state[self.cur_win_pos] |
| | | self.win_sum += cur_frame_state |
| | | self.win_state[self.cur_win_pos] = cur_frame_state |
| | | self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame |
| | | |
| | | if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: |
| | | self.pre_frame_state = FrameState.kFrameStateSpeech |
| | | return AudioChangeState.kChangeStateSil2Speech |
| | | |
| | | if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: |
| | | self.pre_frame_state = FrameState.kFrameStateSil |
| | | return AudioChangeState.kChangeStateSpeech2Sil |
| | | |
| | | if self.pre_frame_state == FrameState.kFrameStateSil: |
| | | return AudioChangeState.kChangeStateSil2Sil |
| | | if self.pre_frame_state == FrameState.kFrameStateSpeech: |
| | | return AudioChangeState.kChangeStateSpeech2Speech |
| | | return AudioChangeState.kChangeStateInvalid |
| | | |
| | | def FrameSizeMs(self) -> int: |
| | | return int(self.frame_size_ms) |
| | | |
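`WindowDetector` smooths the frame decisions with a vote over the last `win_size_frame` frames. With the defaults (200 ms window, 150 ms thresholds, 10 ms frames), it switches to speech once 15 of the last 20 frames are speech and switches back to silence once that count falls to 15 or fewer. The class below is a hypothetical, `collections.deque`-based restatement of the same ring-buffer logic that returns string labels instead of `AudioChangeState` members.

```python
from collections import deque


class SlidingWindowVote:
    """Hypothetical deque-based restatement of WindowDetector.DetectOneFrame."""

    def __init__(self, win_frames: int = 20, sil_to_speech: int = 15, speech_to_sil: int = 15):
        self.window = deque([0] * win_frames, maxlen=win_frames)
        self.win_sum = 0
        self.sil_to_speech = sil_to_speech
        self.speech_to_sil = speech_to_sil
        self.in_speech = False

    def push(self, is_speech: bool) -> str:
        vote = int(is_speech)
        # The deque is full, so append() drops window[0]; keep the running sum in step.
        self.win_sum += vote - self.window[0]
        self.window.append(vote)
        if not self.in_speech and self.win_sum >= self.sil_to_speech:
            self.in_speech = True
            return "sil2speech"
        if self.in_speech and self.win_sum <= self.speech_to_sil:
            self.in_speech = False
            return "speech2sil"
        return "speech2speech" if self.in_speech else "sil2sil"


det = SlidingWindowVote()
events = [det.push(frame) for frame in [True] * 20 + [False] * 20]
print(events.index("sil2speech"), events.index("speech2sil"))  # 14 24
```

The running sum makes each update O(1) regardless of window size, which is the same reason the original keeps `win_sum` alongside `win_state`.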
| | | |
| | | class E2EVadModel(): |
| | | def __init__(self, vad_post_args: Dict[str, Any]): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | | self.vad_opts.sil_to_speech_time_thres, |
| | | self.vad_opts.speech_to_sil_time_thres, |
| | | self.vad_opts.frame_in_ms) |
| | | # Initialize all per-utterance state; AllResetDetection() resets the same fields after the final chunk. |
| | | self.AllResetDetection() |
| | | |
| | | def AllResetDetection(self): |
| | | self.is_final = False |
| | | self.data_buf_start_frame = 0 |
| | | self.frm_cnt = 0 |
| | | self.latest_confirmed_speech_frame = 0 |
| | | self.lastest_confirmed_silence_frame = -1 |
| | | self.continous_silence_frame_count = 0 |
| | | self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected |
| | | self.confirmed_start_frame = -1 |
| | | self.confirmed_end_frame = -1 |
| | | self.number_end_time_detected = 0 |
| | | self.sil_frame = 0 |
| | | self.sil_pdf_ids = self.vad_opts.sil_pdf_ids |
| | | self.noise_average_decibel = -100.0 |
| | | self.pre_end_silence_detected = False |
| | | self.next_seg = True |
| | | |
| | | self.output_data_buf = [] |
| | | self.output_data_buf_offset = 0 |
| | | self.frame_probs = [] |
| | | self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres |
| | | self.speech_noise_thres = self.vad_opts.speech_noise_thres |
| | | self.scores = None |
| | | self.max_time_out = False |
| | | self.decibel = [] |
| | | self.data_buf = None |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.ResetDetection() |
| | | |
| | | def ResetDetection(self): |
| | | self.continous_silence_frame_count = 0 |
| | | self.latest_confirmed_speech_frame = 0 |
| | | self.lastest_confirmed_silence_frame = -1 |
| | | self.confirmed_start_frame = -1 |
| | | self.confirmed_end_frame = -1 |
| | | self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected |
| | | self.windows_detector.Reset() |
| | | self.sil_frame = 0 |
| | | self.frame_probs = [] |
| | | |
| | | def ComputeDecibel(self) -> None: |
| | | frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) |
| | | frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| | | if self.data_buf_all is None: |
| | | self.data_buf_all = self.waveform[0] # self.data_buf points to self.waveform[0] |
| | | self.data_buf = self.data_buf_all |
| | | else: |
| | | self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0])) |
| | | for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length): |
| | | self.decibel.append( |
| | | 10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \ |
| | | 0.000001)) |
| | | |
| | | def ComputeScores(self, scores: np.ndarray) -> None: |
| | | # scores: VAD encoder output of shape (B, T, D) |
| | | self.vad_opts.nn_eval_block_size = scores.shape[1] |
| | | self.frm_cnt += scores.shape[1] # count total frames |
| | | if self.scores is None: |
| | | self.scores = scores # the first calculation |
| | | else: |
| | | self.scores = np.concatenate((self.scores, scores), axis=1) |
| | | |
| | | def PopDataBufTillFrame(self, frame_idx: int) -> None: # needs re-checking |
| | | while self.data_buf_start_frame < frame_idx: |
| | | if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): |
| | | self.data_buf_start_frame += 1 |
| | | self.data_buf = self.data_buf_all[self.data_buf_start_frame * int( |
| | | self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] |
| | | else: |
| | | # less than one frame shift of samples is buffered; stop instead of looping forever |
| | | break |
| | | |
| | | def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, |
| | | last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None: |
| | | self.PopDataBufTillFrame(start_frm) |
| | | expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) |
| | | if last_frm_is_end_point: |
| | | extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \ |
| | | self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) |
| | | expected_sample_number += int(extra_sample) |
| | | if end_point_is_sent_end: |
| | | expected_sample_number = max(expected_sample_number, len(self.data_buf)) |
| | | if len(self.data_buf) < expected_sample_number: |
| | | print('error in calling pop data_buf\n') |
| | | |
| | | if len(self.output_data_buf) == 0 or first_frm_is_start_point: |
| | | self.output_data_buf.append(E2EVadSpeechBufWithDoa()) |
| | | self.output_data_buf[-1].Reset() |
| | | self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms |
| | | self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms |
| | | self.output_data_buf[-1].doa = 0 |
| | | cur_seg = self.output_data_buf[-1] |
| | | if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| | | print('warning\n') |
| | | out_pos = len(cur_seg.buffer) # cur_seg.buffer is currently unused |
| | | data_to_pop = 0 |
| | | if end_point_is_sent_end: |
| | | data_to_pop = expected_sample_number |
| | | else: |
| | | data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| | | if data_to_pop > len(self.data_buf): |
| | | print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n') |
| | | data_to_pop = len(self.data_buf) |
| | | expected_sample_number = len(self.data_buf) |
| | | |
| | | cur_seg.doa = 0 |
| | | for sample_cpy_out in range(0, data_to_pop): |
| | | # cur_seg.buffer[out_pos ++] = data_buf_.back(); |
| | | out_pos += 1 |
| | | for sample_cpy_out in range(data_to_pop, expected_sample_number): |
| | | # cur_seg.buffer[out_pos++] = data_buf_.back() |
| | | out_pos += 1 |
| | | if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| | | print('Something wrong with the VAD algorithm\n') |
| | | self.data_buf_start_frame += frm_cnt |
| | | cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms |
| | | if first_frm_is_start_point: |
| | | cur_seg.contain_seg_start_point = True |
| | | if last_frm_is_end_point: |
| | | cur_seg.contain_seg_end_point = True |
| | | |
| | | def OnSilenceDetected(self, valid_frame: int): |
| | | self.lastest_confirmed_silence_frame = valid_frame |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | self.PopDataBufTillFrame(valid_frame) |
| | | # silence_detected_callback_ |
| | | # pass |
| | | |
| | | def OnVoiceDetected(self, valid_frame: int) -> None: |
| | | self.latest_confirmed_speech_frame = valid_frame |
| | | self.PopDataToOutputBuf(valid_frame, 1, False, False, False) |
| | | |
| | | def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None: |
| | | if self.vad_opts.do_start_point_detection: |
| | | pass |
| | | if self.confirmed_start_frame != -1: |
| | | print('not reset vad properly\n') |
| | | else: |
| | | self.confirmed_start_frame = start_frame |
| | | |
| | | if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False) |
| | | |
| | | def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None: |
| | | for t in range(self.latest_confirmed_speech_frame + 1, end_frame): |
| | | self.OnVoiceDetected(t) |
| | | if self.vad_opts.do_end_point_detection: |
| | | pass |
| | | if self.confirmed_end_frame != -1: |
| | | print('not reset vad properly\n') |
| | | else: |
| | | self.confirmed_end_frame = end_frame |
| | | if not fake_result: |
| | | self.sil_frame = 0 |
| | | self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame) |
| | | self.number_end_time_detected += 1 |
| | | |
| | | def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None: |
| | | if is_final_frame: |
| | | self.OnVoiceEnd(cur_frm_idx, False, True) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | |
| | | def GetLatency(self) -> int: |
| | | return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms) |
| | | |
| | | def LatencyFrmNumAtStartPoint(self) -> int: |
| | | vad_latency = self.windows_detector.GetWinSize() |
| | | if self.vad_opts.do_extend: |
| | | vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms) |
| | | return vad_latency |
| | | |
| | | def GetFrameState(self, t: int) -> FrameState: |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | cur_decibel = self.decibel[t] |
| | | cur_snr = cur_decibel - self.noise_average_decibel |
| | | # for each frame, calc log posterior probability of each state |
| | | if cur_decibel < self.vad_opts.decibel_thres: |
| | | frame_state = FrameState.kFrameStateSil |
| | | self.DetectOneFrame(frame_state, t, False) |
| | | return frame_state |
| | | |
| | | sum_score = 0.0 |
| | | noise_prob = 0.0 |
| | | assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num |
| | | if len(self.sil_pdf_ids) > 0: |
| | | assert len(self.scores) == 1 # only tested with batch_size = 1 |
| | | sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids] |
| | | sum_score = sum(sil_pdf_scores) |
| | | noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio |
| | | total_score = 1.0 |
| | | sum_score = total_score - sum_score |
| | | speech_prob = math.log(sum_score) |
| | | if self.vad_opts.output_frame_probs: |
| | | frame_prob = E2EVadFrameProb() |
| | | frame_prob.noise_prob = noise_prob |
| | | frame_prob.speech_prob = speech_prob |
| | | frame_prob.score = sum_score |
| | | frame_prob.frame_id = t |
| | | self.frame_probs.append(frame_prob) |
| | | if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres: |
| | | if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres: |
| | | frame_state = FrameState.kFrameStateSpeech |
| | | else: |
| | | frame_state = FrameState.kFrameStateSil |
| | | else: |
| | | frame_state = FrameState.kFrameStateSil |
| | | if self.noise_average_decibel < -99.9: |
| | | self.noise_average_decibel = cur_decibel |
| | | else: |
| | | self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * ( |
| | | self.vad_opts.noise_frame_num_used_for_snr |
| | | - 1)) / self.vad_opts.noise_frame_num_used_for_snr |
| | | |
| | | return frame_state |
| | | |
| | | def __call__(self, score: np.ndarray, waveform: np.ndarray, |
| | | is_final: bool = False, max_end_sil: int = 800, online: bool = False |
| | | ): |
| | | self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | self.ComputeScores(score) |
| | | if not is_final: |
| | | self.DetectCommonFrames() |
| | | else: |
| | | self.DetectLastFrames() |
| | | segments = [] |
| | | for batch_num in range(0, score.shape[0]): # only batch_size = 1 is supported for now |
| | | segment_batch = [] |
| | | if len(self.output_data_buf) > 0: |
| | | for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| | | if online: |
| | | if not self.output_data_buf[i].contain_seg_start_point: |
| | | continue |
| | | if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point: |
| | | continue |
| | | start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1 |
| | | if self.output_data_buf[i].contain_seg_end_point: |
| | | end_ms = self.output_data_buf[i].end_ms |
| | | self.next_seg = True |
| | | self.output_data_buf_offset += 1 |
| | | else: |
| | | end_ms = -1 |
| | | self.next_seg = False |
| | | else: |
| | | if not is_final and (not self.output_data_buf[i].contain_seg_start_point |
| | | or not self.output_data_buf[i].contain_seg_end_point): |
| | | continue |
| | | start_ms = self.output_data_buf[i].start_ms |
| | | end_ms = self.output_data_buf[i].end_ms |
| | | self.output_data_buf_offset += 1 |
| | | segment = [start_ms, end_ms] |
| | | segment_batch.append(segment) |
| | | |
| | | if segment_batch: |
| | | segments.append(segment_batch) |
| | | if is_final: |
| | | # reset class variables and clear the dict for the next query |
| | | self.AllResetDetection() |
| | | return segments |
| | | |
| | | def DetectCommonFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | | return 0 |
| | | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | frame_state = self.GetFrameState(self.frm_cnt - 1 - i) |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) |
| | | |
| | | return 0 |
| | | |
| | | def DetectLastFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| | | return 0 |
| | | for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| | | frame_state = FrameState.kFrameStateInvalid |
| | | frame_state = self.GetFrameState(self.frm_cnt - 1 - i) |
| | | if i != 0: |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) |
| | | else: |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1, True) |
| | | |
| | | return 0 |
| | | |
| | | def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None: |
| | | tmp_cur_frm_state = FrameState.kFrameStateInvalid |
| | | if cur_frm_state == FrameState.kFrameStateSpeech: |
| | | if math.fabs(1.0) > self.vad_opts.fe_prior_thres: |
| | | tmp_cur_frm_state = FrameState.kFrameStateSpeech |
| | | else: |
| | | tmp_cur_frm_state = FrameState.kFrameStateSil |
| | | elif cur_frm_state == FrameState.kFrameStateSil: |
| | | tmp_cur_frm_state = FrameState.kFrameStateSil |
| | | state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx) |
| | | frm_shift_in_ms = self.vad_opts.frame_in_ms |
| | | if AudioChangeState.kChangeStateSil2Speech == state_change: |
| | | silence_frame_count = self.continous_silence_frame_count |
| | | self.continous_silence_frame_count = 0 |
| | | self.pre_end_silence_detected = False |
| | | start_frame = 0 |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint()) |
| | | self.OnVoiceStart(start_frame) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment |
| | | for t in range(start_frame + 1, cur_frm_idx + 1): |
| | | self.OnVoiceDetected(t) |
| | | elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| | | for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx): |
| | | self.OnVoiceDetected(t) |
| | | if cur_frm_idx - self.confirmed_start_frame + 1 > \ |
| | | self.vad_opts.max_single_segment_time / frm_shift_in_ms: |
| | | self.OnVoiceEnd(cur_frm_idx, False, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | elif not is_final_frame: |
| | | self.OnVoiceDetected(cur_frm_idx) |
| | | else: |
| | | self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) |
| | | else: |
| | | pass |
| | | elif AudioChangeState.kChangeStateSpeech2Sil == state_change: |
| | | self.continous_silence_frame_count = 0 |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | pass |
| | | elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| | | if cur_frm_idx - self.confirmed_start_frame + 1 > \ |
| | | self.vad_opts.max_single_segment_time / frm_shift_in_ms: |
| | | self.OnVoiceEnd(cur_frm_idx, False, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | elif not is_final_frame: |
| | | self.OnVoiceDetected(cur_frm_idx) |
| | | else: |
| | | self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) |
| | | else: |
| | | pass |
| | | elif AudioChangeState.kChangeStateSpeech2Speech == state_change: |
| | | self.continous_silence_frame_count = 0 |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| | | if cur_frm_idx - self.confirmed_start_frame + 1 > \ |
| | | self.vad_opts.max_single_segment_time / frm_shift_in_ms: |
| | | self.max_time_out = True |
| | | self.OnVoiceEnd(cur_frm_idx, False, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | elif not is_final_frame: |
| | | self.OnVoiceDetected(cur_frm_idx) |
| | | else: |
| | | self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) |
| | | else: |
| | | pass |
| | | elif AudioChangeState.kChangeStateSil2Sil == state_change: |
| | | self.continous_silence_frame_count += 1 |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| | | # silence timeout, return zero length decision |
| | | if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( |
| | | self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ |
| | | or (is_final_frame and self.number_end_time_detected == 0): |
| | | for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx): |
| | | self.OnSilenceDetected(t) |
| | | self.OnVoiceStart(0, True) |
| | | self.OnVoiceEnd(0, True, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | else: |
| | | if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(): |
| | | self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint()) |
| | | elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| | | if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh: |
| | | lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms) |
| | | if self.vad_opts.do_extend: |
| | | lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) |
| | | lookback_frame -= 1 |
| | | lookback_frame = max(0, lookback_frame) |
| | | self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | elif cur_frm_idx - self.confirmed_start_frame + 1 > \ |
| | | self.vad_opts.max_single_segment_time / frm_shift_in_ms: |
| | | self.OnVoiceEnd(cur_frm_idx, False, False) |
| | | self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| | | elif self.vad_opts.do_extend and not is_final_frame: |
| | | if self.continous_silence_frame_count <= int( |
| | | self.vad_opts.lookahead_time_end_point / frm_shift_in_ms): |
| | | self.OnVoiceDetected(cur_frm_idx) |
| | | else: |
| | | self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) |
| | | else: |
| | | pass |
| | | |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ |
| | | self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: |
| | | self.ResetDetection() |
| | | |
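The end-point arithmetic in the `kChangeStateSil2Sil` and max-segment branches above can be checked in isolation. The sketch below is a standalone restatement of those lines, not part of FunASR. The function names are hypothetical, and the millisecond values in the usage are illustrative rather than the model's configured defaults. Note that `max_end_sil_frame_cnt_thresh` is used there as a duration in milliseconds, since it is compared against `count * frm_shift_in_ms`.

```python
def end_point_lookback_frames(max_end_sil_ms: int, frame_ms: int,
                              do_extend: bool, lookahead_ms: int) -> int:
    """Frames to step back from the current frame when closing a segment.

    Mirrors the lookback_frame computation: the trailing-silence window,
    minus the optional look-ahead extension, minus one frame, clamped at 0.
    """
    lookback = int(max_end_sil_ms / frame_ms)
    if do_extend:
        lookback -= int(lookahead_ms / frame_ms)
    lookback -= 1
    return max(0, lookback)


def segment_too_long(cur_frm_idx: int, confirmed_start_frame: int,
                     max_single_segment_ms: int, frame_ms: int) -> bool:
    """True once a segment exceeds max_single_segment_time (forced end point)."""
    return cur_frm_idx - confirmed_start_frame + 1 > max_single_segment_ms / frame_ms


# Illustrative values: 10 ms frames, 800 ms end silence, 200 ms look-ahead.
print(end_point_lookback_frames(800, 10, True, 200))   # 59
print(end_point_lookback_frames(800, 10, False, 200))  # 79
print(segment_too_long(6000, 0, 60000, 10))            # True
```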
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | import copy |
| | | |
| | | import numpy as np |
| | | from typeguard import check_argument_types |
| | |
| | | cmvn = np.array([means, vars]) |
| | | return cmvn |
| | | |
| | | |
| | | class WavFrontendOnline(WavFrontend): |
| | | def __init__(self, **kwargs): |
| | | super().__init__(**kwargs) |
| | | # self.fbank_fn = knf.OnlineFbank(self.opts) |
| | | # add variables |
| | | self.frame_sample_length = int(self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000) |
| | | self.frame_shift_sample_length = int(self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000) |
| | | self.waveform = None |
| | | self.reserve_waveforms = None |
| | | self.input_cache = None |
| | | self.lfr_splice_cache = [] |
| | | |
| | | @staticmethod |
| | | # inputs already has the LFR splice cache concatenated in front |
| | | def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[ |
| | | np.ndarray, np.ndarray, int]: |
| | | """ |
| | | Apply low frame rate (LFR) stacking: concatenate lfr_m consecutive frames, advancing lfr_n frames per output frame. |
| | | """ |
| | | |
| | | LFR_inputs = [] |
| | | T = inputs.shape[0] # include the right context |
| | | T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2 |
| | | splice_idx = T_lfr |
| | | for i in range(T_lfr): |
| | | if lfr_m <= T - i * lfr_n: |
| | | LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1)) |
| | | else: # process last LFR frame |
| | | if is_final: |
| | | num_padding = lfr_m - (T - i * lfr_n) |
| | | frame = (inputs[i * lfr_n:]).reshape(-1) |
| | | for _ in range(num_padding): |
| | | frame = np.hstack((frame, inputs[-1])) |
| | | LFR_inputs.append(frame) |
| | | else: |
| | | # not enough frames for a full LFR window yet: record splice_idx and stop the loop |
| | | splice_idx = i |
| | | break |
| | | splice_idx = min(T - 1, splice_idx * lfr_n) |
| | | lfr_splice_cache = inputs[splice_idx:, :] |
| | | LFR_outputs = np.vstack(LFR_inputs) |
| | | return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx |
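For readers new to low frame rate (LFR) features: `apply_lfr` stacks `lfr_m` consecutive frames into one vector and advances `lfr_n` frames at a time. The sketch below covers only the offline case, which is equivalent to `is_final=True` with the left context already prepended. It pads the tail by repeating the last frame, as the method above does, and it omits the splice-cache bookkeeping that the streaming path needs. The name `apply_lfr_offline` is hypothetical.

```python
import numpy as np


def apply_lfr_offline(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
    """Stack lfr_m consecutive frames with stride lfr_n (offline-only sketch)."""
    T = feats.shape[0]
    # Same output length as apply_lfr: the right context (lfr_m - 1) // 2 is excluded.
    T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))
    out = []
    for i in range(T_lfr):
        start = i * lfr_n
        window = feats[start:start + lfr_m]
        if window.shape[0] < lfr_m:
            # Pad the last LFR frame by repeating the final input frame.
            pad = np.repeat(feats[-1:], lfr_m - window.shape[0], axis=0)
            window = np.vstack((window, pad))
        out.append(window.reshape(-1))
    return np.vstack(out).astype(np.float32)


feats = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 frames, 2 dims
out = apply_lfr_offline(feats, lfr_m=3, lfr_n=2)
print(out.shape)         # (3, 6)
print(out[2].tolist())   # [8.0, 9.0, 10.0, 11.0, 10.0, 11.0]
```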
| | | |
| | | @staticmethod |
| | | def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int: |
| | | frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1) |
| | | return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0 |
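A worked example of `compute_frame_num` and of the leftover samples that `fbank` keeps in `self.input_cache`. It assumes 16 kHz audio with a 25 ms window (400 samples) and a 10 ms shift (160 samples); these are illustrative values, and the real ones come from `frame_opts`. Because the window is longer than the shift, the next frame always starts at `frame_num * shift`, which is where the cached remainder begins. The helper `cached_remainder` is hypothetical.

```python
def compute_frame_num(sample_length: int, frame_len: int, frame_shift: int) -> int:
    """Number of complete frames in sample_length samples (same formula as above)."""
    frame_num = int((sample_length - frame_len) / frame_shift + 1)
    return frame_num if frame_num >= 1 and sample_length >= frame_len else 0


def cached_remainder(sample_length: int, frame_len: int, frame_shift: int) -> int:
    """Samples carried over to the next call: everything from the next frame start."""
    frame_num = compute_frame_num(sample_length, frame_len, frame_shift)
    return sample_length - frame_num * frame_shift


print(compute_frame_num(16000, 400, 160))  # 98 frames in one second
print(cached_remainder(16000, 400, 160))   # 320 samples kept for the next chunk
print(compute_frame_num(399, 400, 160))    # 0: shorter than one window
```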
| | | |
| | | |
| | | def fbank( |
| | | self, |
| | | input: np.ndarray, |
| | | input_lengths: np.ndarray |
| | | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| | | self.fbank_fn = knf.OnlineFbank(self.opts) |
| | | batch_size = input.shape[0] |
| | | if self.input_cache is None: |
| | | self.input_cache = np.empty((batch_size, 0), dtype=np.float32) |
| | | input = np.concatenate((self.input_cache, input), axis=1) |
| | | frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length) |
| | | # keep the samples from the next frame start onward in self.input_cache |
| | | self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):] |
| | | waveforms = np.empty(0, dtype=np.int16) |
| | | feats_pad = np.empty(0, dtype=np.float32) |
| | | feats_lens = np.empty(0, dtype=np.int32) |
| | | if frame_num: |
| | | waveforms = [] |
| | | feats = [] |
| | | feats_lens = [] |
| | | for i in range(batch_size): |
| | | waveform = input[i] |
| | | waveforms.append( |
| | | waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)]) |
| | | waveform = waveform * (1 << 15)  # scale float samples in [-1, 1] to the int16 range used by Kaldi-style fbank |
| | | |
| | | self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) |
| | | frames = self.fbank_fn.num_frames_ready |
| | | mat = np.empty([frames, self.opts.mel_opts.num_bins]) |
| | | for j in range(frames):  # use j so the batch index i is not shadowed |
| | | mat[j, :] = self.fbank_fn.get_frame(j) |
| | | feat = mat.astype(np.float32) |
| | | feat_len = np.array(mat.shape[0]).astype(np.int32) |
| | | feats.append(feat) |
| | | feats_lens.append(feat_len) |
| | | |
| | | waveforms = np.stack(waveforms) |
| | | feats_lens = np.array(feats_lens) |
| | | feats_pad = np.array(feats) |
| | | self.fbanks = feats_pad |
| | | self.fbanks_lens = copy.deepcopy(feats_lens) |
| | | return waveforms, feats_pad, feats_lens |
| | | |
| | | def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]: |
| | | return self.fbanks, self.fbanks_lens |
| | | |
| | | def lfr_cmvn( |
| | | self, |
| | | input: np.ndarray, |
| | | input_lengths: np.ndarray, |
| | | is_final: bool = False |
| | | ) -> Tuple[np.ndarray, np.ndarray, List[int]]: |
| | | batch_size = input.shape[0] |
| | | feats = [] |
| | | feats_lens = [] |
| | | lfr_splice_frame_idxs = [] |
| | | for i in range(batch_size): |
| | | mat = input[i, :input_lengths[i], :] |
| | | lfr_splice_frame_idx = -1 |
| | | if self.lfr_m != 1 or self.lfr_n != 1: |
| | | # update self.lfr_splice_cache in self.apply_lfr |
| | | mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, |
| | | is_final) |
| | | if self.cmvn_file is not None: |
| | | mat = self.apply_cmvn(mat) |
| | | feat_length = mat.shape[0] |
| | | feats.append(mat) |
| | | feats_lens.append(feat_length) |
| | | lfr_splice_frame_idxs.append(lfr_splice_frame_idx) |
| | | |
| | | feats_lens = np.array(feats_lens) |
| | | feats_pad = np.array(feats) |
| | | return feats_pad, feats_lens, lfr_splice_frame_idxs |
| | | |
| | | |
| | | def extract_fbank( |
| | | self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False |
| | | ) -> Tuple[np.ndarray, np.ndarray]: |
| | | batch_size = input.shape[0] |
| | | assert batch_size == 1, 'online feature extraction currently supports only batch_size == 1' |
| | | waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D |
| | | if feats.shape[0]: |
| | | self.waveforms = waveforms if self.reserve_waveforms is None else np.concatenate( |
| | | (self.reserve_waveforms, waveforms), axis=1) |
| | | if not self.lfr_splice_cache: |
| | | for i in range(batch_size): |
| | | self.lfr_splice_cache.append(np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)) |
| | | |
| | | if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m: |
| | | lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D |
| | | feats = np.concatenate((lfr_splice_cache_np, feats), axis=1) |
| | | feats_lengths += lfr_splice_cache_np[0].shape[0] |
| | | frame_from_waveforms = int( |
| | | (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1) |
| | | minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0 |
| | | feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(feats, feats_lengths, is_final) |
| | | if self.lfr_m == 1: |
| | | self.reserve_waveforms = None |
| | | else: |
| | | reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame |
| | | # print('reserve_frame_idx: ' + str(reserve_frame_idx)) |
| | | # print('frame_frame: ' + str(frame_from_waveforms)) |
| | | self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length] |
| | | sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length |
| | | self.waveforms = self.waveforms[:, :sample_length] |
| | | else: |
| | | # update self.reserve_waveforms and self.lfr_splice_cache |
| | | self.reserve_waveforms = self.waveforms[:, |
| | | :-(self.frame_sample_length - self.frame_shift_sample_length)] |
| | | for i in range(batch_size): |
| | | self.lfr_splice_cache[i] = np.concatenate((self.lfr_splice_cache[i], feats[i]), axis=0) |
| | | return np.empty(0, dtype=np.float32), feats_lengths |
| | | else: |
| | | if is_final: |
| | | self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms |
| | | feats = np.stack(self.lfr_splice_cache) |
| | | feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1] |
| | | feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final) |
| | | if is_final: |
| | | self.cache_reset() |
| | | return feats, feats_lengths |
| | | |
| | | def get_waveforms(self): |
| | | return self.waveforms |
| | | |
| | | def cache_reset(self): |
| | | self.fbank_fn = knf.OnlineFbank(self.opts) |
| | | self.reserve_waveforms = None |
| | | self.input_cache = None |
| | | self.lfr_splice_cache = [] |
| | | |
| | | def load_bytes(input): |
| | | middle_data = np.frombuffer(input, dtype=np.int16) |
| | | middle_data = np.asarray(middle_data) |
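`load_bytes` is truncated in this hunk, so its remaining lines are not reproduced here. As a separate, hedged illustration: one common way to turn raw 16-bit PCM bytes into float samples in [-1, 1) is to divide by 32768, the inverse of the `waveform * (1 << 15)` scaling applied in `WavFrontendOnline.fbank`. The sketch assumes mono, native-endian (typically little-endian) int16 data. `pcm16_bytes_to_float` is a hypothetical name.

```python
import numpy as np


def pcm16_bytes_to_float(pcm: bytes) -> np.ndarray:
    """Decode native-endian int16 PCM bytes into float32 samples in [-1, 1)."""
    samples = np.frombuffer(pcm, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0


raw = np.array([0, 16384, -32768, 32767], dtype=np.int16).tobytes()
print(pcm16_bytes_to_float(raw).tolist())
```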
| | |
| | | return feat, feat_len |
| | | |
| | | if __name__ == '__main__': |
| | | test() |
| | | test() |
| | |
| | | ): |
| | | check_argument_types() |
| | | |
| | | # self.token_list = self.load_token(token_path) |
| | | self.token_list = token_list |
| | | self.unk_symbol = token_list[-1] |
| | | self.token2id = {v: i for i, v in enumerate(self.token_list)} |
| | | self.unk_id = self.token2id[self.unk_symbol] |
| | | |
| | | # @staticmethod |
| | | # def load_token(file_path: Union[Path, str]) -> List: |
| | | # if not Path(file_path).exists(): |
| | | # raise TokenIDConverterError(f'The {file_path} does not exist.') |
| | | # |
| | | # with open(str(file_path), 'rb') as f: |
| | | # token_list = pickle.load(f) |
| | | # |
| | | # if len(token_list) != len(set(token_list)): |
| | | # raise TokenIDConverterError('The Token exists duplicated symbol.') |
| | | # return token_list |
| | | |
| | | def get_num_vocabulary_size(self) -> int: |
| | | return len(self.token_list) |
| | |
| | | return [self.token_list[i] for i in integers] |
| | | |
| | | def tokens2ids(self, tokens: Iterable[str]) -> List[int]: |
| | | token2id = {v: i for i, v in enumerate(self.token_list)} |
| | | if self.unk_symbol not in token2id: |
| | | raise TokenIDConverterError( |
| | | f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list" |
| | | ) |
| | | unk_id = token2id[self.unk_symbol] |
| | | return [token2id.get(i, unk_id) for i in tokens] |
| | | |
| | | return [self.token2id.get(i, self.unk_id) for i in tokens] |
| | | |
| | | |
| | | class CharTokenizer(): |
| | |
| | | input_content: List[np.ndarray]) -> np.ndarray: |
| | | input_dict = dict(zip(self.get_input_names(), input_content)) |
| | | try: |
| | | return self.session.run(None, input_dict) |
| | | return self.session.run(self.get_output_names(), input_dict) |
| | | except Exception as e: |
| | | raise ONNXRuntimeError('ONNXRuntime inference failed.') from e |
| | | |
| | |
| | | if not model_path.is_file(): |
| | | raise FileNotFoundError(f'{model_path} is not a file.') |
| | | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20): |
| | | assert word_limit > 1 |
| | | if len(words) <= word_limit: |
| | | return [words] |
| | | sentences = [] |
| | | length = len(words) |
| | | sentence_len = length // word_limit |
| | | for i in range(sentence_len): |
| | | sentences.append(words[i * word_limit:(i + 1) * word_limit]) |
| | | if length % word_limit > 0: |
| | | sentences.append(words[sentence_len * word_limit:]) |
| | | return sentences |
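`split_to_mini_sentence` cuts a word list into chunks of at most `word_limit` words, with any remainder as a final shorter chunk. The sketch below is an equivalent slicing-based restatement, not the FunASR code itself. It is kept self-contained so the behaviour can be checked directly.

```python
from typing import List


def split_to_mini_sentence(words: list, word_limit: int = 20) -> List[list]:
    """Split words into consecutive chunks of at most word_limit items."""
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]  # short input (including an empty list) stays one chunk
    return [words[i:i + word_limit] for i in range(0, len(words), word_limit)]


chunks = split_to_mini_sentence(list(range(45)), word_limit=20)
print([len(c) for c in chunks])  # [20, 20, 5]
```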
| | | |
| | | def code_mix_split_words(text: str): |
| | | words = [] |
| | | segs = text.split() |
| | | for seg in segs: |
| | | # There is no space in seg. |
| | | current_word = "" |
| | | for c in seg: |
| | | if len(c.encode()) == 1: |
| | | # This is an ASCII char. |
| | | current_word += c |
| | | else: |
| | | # A non-ASCII (e.g. Chinese) char becomes a word on its own. |
| | | if len(current_word) > 0: |
| | | words.append(current_word) |
| | | current_word = "" |
| | | words.append(c) |
| | | if len(current_word) > 0: |
| | | words.append(current_word) |
| | | return words |
| | | |
| | | def read_yaml(yaml_path: Union[str, Path]) -> Dict: |
| | | if not Path(yaml_path).exists(): |
| | |
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='rapdi_paraformer'): |
| | | def get_logger(name='funasr_onnx'): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| New file |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import os.path |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | | |
| | | import copy |
| | | import librosa |
| | | import numpy as np |
| | | |
| | | from .utils.utils import (ONNXRuntimeError, |
| | | OrtInferSession, get_logger, |
| | | read_yaml) |
| | | from .utils.frontend import WavFrontend, WavFrontendOnline |
| | | from .utils.e2e_vad import E2EVadModel |
| | | |
| | | logging = get_logger() |
| | | |
| | | |
| | | class Fsmn_vad(): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4, |
| | | max_end_sil: int = None, |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | | raise FileNotFoundError(f'{model_dir} does not exist.') |
| | | |
| | | model_file = os.path.join(model_dir, 'model.onnx') |
| | | if quantize: |
| | | model_file = os.path.join(model_dir, 'model_quant.onnx') |
| | | config_file = os.path.join(model_dir, 'vad.yaml') |
| | | cmvn_file = os.path.join(model_dir, 'vad.mvn') |
| | | config = read_yaml(config_file) |
| | | |
| | | self.frontend = WavFrontend( |
| | | cmvn_file=cmvn_file, |
| | | **config['frontend_conf'] |
| | | ) |
| | | self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) |
| | | self.batch_size = batch_size |
| | | self.vad_scorer = E2EVadModel(config["vad_post_conf"]) |
| | | self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"] |
| | | self.encoder_conf = config["encoder_conf"] |
| | | |
| | | def prepare_cache(self, in_cache: list = None): |
| | | if in_cache: |
| | | return in_cache |
| | | in_cache = []  # fresh list; avoids a shared mutable default argument |
| | | fsmn_layers = self.encoder_conf["fsmn_layers"] |
| | | proj_dim = self.encoder_conf["proj_dim"] |
| | | lorder = self.encoder_conf["lorder"] |
| | | for i in range(fsmn_layers): |
| | | cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32) |
| | | in_cache.append(cache) |
| | | return in_cache |
| | | |
| | | |
| | | def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List: |
| | | waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq) |
| | | waveform_nums = len(waveform_list) |
| | | is_final = kwargs.get('is_final', False) |
| | | |
| | | segments = [[] for _ in range(self.batch_size)]  # independent list per batch entry |
| | | for beg_idx in range(0, waveform_nums, self.batch_size): |
| | | |
| | | end_idx = min(waveform_nums, beg_idx + self.batch_size) |
| | | waveform = waveform_list[beg_idx:end_idx] |
| | | feats, feats_len = self.extract_feat(waveform) |
| | | waveform = np.array(waveform) |
| | | param_dict = kwargs.get('param_dict', dict()) |
| | | in_cache = param_dict.get('in_cache', list()) |
| | | in_cache = self.prepare_cache(in_cache) |
| | | try: |
| | | t_offset = 0 |
| | | step = int(min(feats_len.max(), 6000)) |
| | | for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)): |
| | | if t_offset + step >= feats_len - 1: |
| | | step = feats_len - t_offset |
| | | is_final = True |
| | | else: |
| | | is_final = False |
| | | feats_package = feats[:, t_offset:int(t_offset + step), :] |
| | | # frames -> samples: 160-sample (10 ms) shift, 400-sample (25 ms) window at 16 kHz |
| | | waveform_package = waveform[:, t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400)] |
| | | |
| | | inputs = [feats_package] |
| | | # inputs = [feats] |
| | | inputs.extend(in_cache) |
| | | scores, out_caches = self.infer(inputs) |
| | | in_cache = out_caches |
| | | segments_part = self.vad_scorer(scores, waveform_package, is_final=is_final, max_end_sil=self.max_end_sil, online=False) |
| | | # segments = self.vad_scorer(scores, waveform[0][None, :], is_final=is_final, max_end_sil=self.max_end_sil) |
| | | |
| | | if segments_part: |
| | | for batch_num in range(0, self.batch_size): |
| | | segments[batch_num] += segments_part[batch_num] |
| | | |
| | | except ONNXRuntimeError: |
| | | # logging.warning(traceback.format_exc()) |
| | | logging.warning("input wav is silence or noise") |
| | | segments = [] |
| | | |
| | | return segments |
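The offline `Fsmn_vad.__call__` processes long inputs in chunks of about 6000 feature frames (the last chunk may absorb a one-frame remainder). It maps each chunk back to a waveform slice using a 160-sample shift and a 400-sample window, which is 10 ms and 25 ms at 16 kHz. The sketch below restates that arithmetic with hypothetical helper names. It differs deliberately in one respect: it advances by the step actually consumed, whereas a fixed-stride `range` can revisit a one-frame remainder that was already absorbed into the final chunk.

```python
from typing import List, Tuple


def plan_chunks(feats_len: int, max_step: int = 6000) -> List[Tuple[int, int, bool]]:
    """Return (t_offset, step, is_final) for each chunk of feature frames."""
    chunks = []
    step = min(feats_len, max_step)
    t_offset = 0
    while t_offset < feats_len:
        if t_offset + step >= feats_len - 1:
            step = feats_len - t_offset  # last chunk takes whatever is left
            is_final = True
        else:
            is_final = False
        chunks.append((t_offset, step, is_final))
        t_offset += step
    return chunks


def chunk_sample_range(t_offset: int, step: int, num_samples: int,
                       shift: int = 160, win: int = 400) -> Tuple[int, int]:
    """Waveform slice [start, end) covered by frames t_offset .. t_offset + step - 1."""
    start = t_offset * shift
    end = min(num_samples, (t_offset + step - 1) * shift + win)
    return start, end


print(plan_chunks(12001))                 # [(0, 6000, False), (6000, 6001, True)]
print(chunk_sample_range(0, 500, 80400))  # (0, 80240)
```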
| | | |
| | | def load_data(self, |
| | | wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: |
| | | def load_wav(path: str) -> np.ndarray: |
| | | waveform, _ = librosa.load(path, sr=fs) |
| | | return waveform |
| | | |
| | | if isinstance(wav_content, np.ndarray): |
| | | return [wav_content] |
| | | |
| | | if isinstance(wav_content, str): |
| | | return [load_wav(wav_content)] |
| | | |
| | | if isinstance(wav_content, list): |
| | | return [load_wav(path) for path in wav_content] |
| | | |
| | | raise TypeError( |
| | | f'The type of {wav_content} is not in [str, np.ndarray, list]') |
| | | |
| | | def extract_feat(self, |
| | | waveform_list: List[np.ndarray] |
| | | ) -> Tuple[np.ndarray, np.ndarray]: |
| | | feats, feats_len = [], [] |
| | | for waveform in waveform_list: |
| | | speech, _ = self.frontend.fbank(waveform) |
| | | feat, feat_len = self.frontend.lfr_cmvn(speech) |
| | | feats.append(feat) |
| | | feats_len.append(feat_len) |
| | | |
| | | feats = self.pad_feats(feats, np.max(feats_len)) |
| | | feats_len = np.array(feats_len).astype(np.int32) |
| | | return feats, feats_len |
| | | |
| | | @staticmethod |
| | | def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: |
| | | def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: |
| | | pad_width = ((0, max_feat_len - cur_len), (0, 0)) |
| | | return np.pad(feat, pad_width, 'constant', constant_values=0) |
| | | |
| | | feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] |
| | | feats = np.array(feat_res).astype(np.float32) |
| | | return feats |
| | | |
| | | def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]: |
| | | |
| | | outputs = self.ort_infer(feats) |
| | | scores, out_caches = outputs[0], outputs[1:] |
| | | return scores, out_caches |
| | | |
| | | |
| | | class Fsmn_vad_online(): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Deep-FSMN for Large Vocabulary Continuous Speech Recognition |
| | | https://arxiv.org/abs/1803.05030 |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4, |
| | | max_end_sil: int = None, |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | | raise FileNotFoundError(f'{model_dir} does not exist.') |
| | | |
| | | model_file = os.path.join(model_dir, 'model.onnx') |
| | | if quantize: |
| | | model_file = os.path.join(model_dir, 'model_quant.onnx') |
| | | config_file = os.path.join(model_dir, 'vad.yaml') |
| | | cmvn_file = os.path.join(model_dir, 'vad.mvn') |
| | | config = read_yaml(config_file) |
| | | |
| | | self.frontend = WavFrontendOnline( |
| | | cmvn_file=cmvn_file, |
| | | **config['frontend_conf'] |
| | | ) |
| | | self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) |
| | | self.batch_size = batch_size |
| | | self.vad_scorer = E2EVadModel(config["vad_post_conf"]) |
| | | self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"] |
| | | self.encoder_conf = config["encoder_conf"] |
| | | |
| | | def prepare_cache(self, in_cache: list = None): |
| | | if in_cache: |
| | | return in_cache |
| | | in_cache = []  # fresh list; avoids a shared mutable default argument |
| | | fsmn_layers = self.encoder_conf["fsmn_layers"] |
| | | proj_dim = self.encoder_conf["proj_dim"] |
| | | lorder = self.encoder_conf["lorder"] |
| | | for i in range(fsmn_layers): |
| | | cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32) |
| | | in_cache.append(cache) |
| | | return in_cache |
| | | |
| | | def __call__(self, audio_in: np.ndarray, **kwargs) -> List: |
| | | waveforms = np.expand_dims(audio_in, axis=0) |
| | | |
| | | param_dict = kwargs.get('param_dict', dict()) |
| | | is_final = param_dict.get('is_final', False) |
| | | feats, feats_len = self.extract_feat(waveforms, is_final) |
| | | segments = [] |
| | | if feats.size != 0: |
| | | in_cache = param_dict.get('in_cache', list()) |
| | | in_cache = self.prepare_cache(in_cache) |
| | | try: |
| | | inputs = [feats] |
| | | inputs.extend(in_cache) |
| | | scores, out_caches = self.infer(inputs) |
| | | param_dict['in_cache'] = out_caches |
| | | waveforms = self.frontend.get_waveforms() |
| | | segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil, |
| | | online=True) |
| | | |
| | | |
| | | except ONNXRuntimeError: |
| | | # logging.warning(traceback.format_exc()) |
| | | logging.warning("input wav is silence or noise") |
| | | segments = [] |
| | | return segments |
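`Fsmn_vad_online.__call__` expects to be fed consecutive chunks of one waveform, with `param_dict['is_final']` set on the last chunk. It stores the FSMN caches back into `param_dict['in_cache']` between calls. The helper below only does the chunking and final-chunk flagging. The 1600-sample chunk size (100 ms at 16 kHz) and the name `iter_stream_chunks` are assumptions for illustration, and the commented model loop is a hedged usage outline rather than a tested script.

```python
from typing import Iterator, Tuple

import numpy as np


def iter_stream_chunks(waveform: np.ndarray,
                       chunk_size: int) -> Iterator[Tuple[np.ndarray, bool]]:
    """Yield (chunk, is_final) pairs covering waveform in order."""
    num_samples = len(waveform)
    for start in range(0, num_samples, chunk_size):
        end = min(start + chunk_size, num_samples)
        yield waveform[start:end], end == num_samples


# Outline of feeding Fsmn_vad_online (requires an exported model directory):
# model = Fsmn_vad_online(model_dir)
# param_dict = {"in_cache": []}
# for chunk, is_final in iter_stream_chunks(speech, 1600):
#     param_dict["is_final"] = is_final
#     segments = model(chunk, param_dict=param_dict)

print([(len(c), f) for c, f in iter_stream_chunks(np.zeros(4000), 1600)])
# [(1600, False), (1600, False), (800, True)]
```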
| | | |
| | | def load_data(self, |
| | | wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: |
| | | def load_wav(path: str) -> np.ndarray: |
| | | waveform, _ = librosa.load(path, sr=fs) |
| | | return waveform |
| | | |
| | | if isinstance(wav_content, np.ndarray): |
| | | return [wav_content] |
| | | |
| | | if isinstance(wav_content, str): |
| | | return [load_wav(wav_content)] |
| | | |
| | | if isinstance(wav_content, list): |
| | | return [load_wav(path) for path in wav_content] |
| | | |
| | | raise TypeError( |
| | | f'The type of {wav_content} is not in [str, np.ndarray, list]') |
| | | |
| | | def extract_feat(self, |
| | | waveforms: np.ndarray, is_final: bool = False |
| | | ) -> Tuple[np.ndarray, np.ndarray]: |
| | | waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32) |
| | | for idx, waveform in enumerate(waveforms): |
| | | waveforms_lens[idx] = waveform.shape[-1] |
| | | |
| | | feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final) |
| | | # feats.append(feat) |
| | | # feats_len.append(feat_len) |
| | | |
| | | # feats = self.pad_feats(feats, np.max(feats_len)) |
| | | # feats_len = np.array(feats_len).astype(np.int32) |
| | | return feats.astype(np.float32), feats_len.astype(np.int32) |
| | | |
| | | @staticmethod |
| | | def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: |
| | | def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: |
| | | pad_width = ((0, max_feat_len - cur_len), (0, 0)) |
| | | return np.pad(feat, pad_width, 'constant', constant_values=0) |
| | | |
| | | feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] |
| | | feats = np.array(feat_res).astype(np.float32) |
| | | return feats |
| | | |
| | | def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]: |
| | | |
| | | outputs = self.ort_infer(feats) |
| | | scores, out_caches = outputs[0], outputs[1:] |
| | | return scores, out_caches |
| | | |
| | |
| | | |
| | | |
| | | MODULE_NAME = 'funasr_onnx' |
| | | VERSION_NUM = '0.0.2' |
| | | VERSION_NUM = '0.0.5' |
| | | |
| | | setuptools.setup( |
| | | name=MODULE_NAME, |
| | | version=VERSION_NUM, |
| | | platforms="Any", |
| | | url="https://github.com/alibaba-damo-academy/FunASR.git", |
| | | author="Speech Lab, Alibaba Group, China", |
| | | author="Speech Lab of DAMO Academy, Alibaba Group", |
| | | author_email="funasr@list.alibaba-inc.com", |
| | | description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit", |
| | | license='MIT', |
| New file |
| | |
| | | |
| | | import time |
| | | import sys |
| | | import librosa |
| | | from funasr.utils.types import str2bool |
| | | |
| | | import argparse |
| | | parser = argparse.ArgumentParser() |
| | | parser.add_argument('--model_dir', type=str, required=True) |
| | | parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]') |
| | | parser.add_argument('--wav_file', type=str, default=None, help='list file with one "<utt_id> <wav_path>" entry per line') |
| | | parser.add_argument('--quantize', type=str2bool, default=False, help='quantized model') |
| | | parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx') |
| | | parser.add_argument('--batch_size', type=int, default=1, help='batch_size for onnx') |
| | | args = parser.parse_args() |
| | | |
| | | |
| | | if args.backend == "onnx": |
| | | from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer |
| | | else:  # "torch": import only when needed so the ONNX path does not need the torch backend |
| | | from funasr.runtime.python.libtorch.funasr_torch import Paraformer |
| | | |
| | | model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads) |
| | | |
| | | with open(args.wav_file, 'r') as wav_file_f: |
| | | wav_files = wav_file_f.readlines() |
| | | |
| | | # warm-up |
| | | total = 0.0 |
| | | num = 30 |
| | | wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip() |
| | | for i in range(num): |
| | | beg_time = time.time() |
| | | result = model(wav_path) |
| | | end_time = time.time() |
| | | duration = end_time-beg_time |
| | | total += duration |
| | | print(result) |
| | | print("num: {}, time: {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53)) |
| | | |
| | | # infer time |
| | | wav_path = [] |
| | | beg_time = time.time() |
| | | for i, wav_path_i in enumerate(wav_files): |
| | | wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip() |
| | | wav_path += [wav_path_i] |
| | | result = model(wav_path) |
| | | end_time = time.time() |
| | | duration = (end_time-beg_time)*1000 |
| | | print("total_time_comput_ms: {}".format(int(duration))) |
| | | |
| | | duration_time = 0.0 |
| | | for i, wav_path_i in enumerate(wav_files): |
| | | wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip() |
| | | waveform, _ = librosa.load(wav_path, sr=16000) |
| | | duration_time += len(waveform)/16.0 |
| | | print("total_time_wav_ms: {}".format(int(duration_time))) |
| | | |
| | | print("total_rtf: {:.5}".format(duration/duration_time)) |
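The script reports RTF as computation time divided by audio duration, both in milliseconds. `len(waveform)/16.0` converts a sample count to milliseconds only because the audio is loaded at 16 kHz. Below is a minimal restatement of that arithmetic, generalised over the sample rate (the function name is hypothetical).

```python
def real_time_factor(compute_ms: float, num_samples: int, sample_rate: int = 16000) -> float:
    """RTF = processing time / audio duration; values below 1 are faster than real time."""
    audio_ms = num_samples / (sample_rate / 1000.0)
    return compute_ms / audio_ms


print(real_time_factor(500.0, 16000 * 10))  # 0.05: 0.5 s of compute for 10 s of audio
```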
| | |
| | | default=sys.maxsize, |
| | | help="The maximum number of update steps to train", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | help="Number of batches between model checkpoints.", |
| | | ) |
| | | group.add_argument( |
| | | "--patience", |
| | | type=int_or_none, |
| | |
| | | ) -> AbsIterFactory: |
| | | assert check_argument_types() |
| | | |
| | | if hasattr(args, "frontend_conf"): |
| | | if args.frontend_conf is not None and "fs" in args.frontend_conf: |
| | | dest_sample_rate = args.frontend_conf["fs"] |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | else: |
| | | dest_sample_rate = 16000 |
| | | |
| | | dataset = ESPnetDataset( |
| | | iter_options.data_path_and_name_and_type, |
| | | float_dtype=args.train_dtype, |
| | | preprocess=iter_options.preprocess_fn, |
| | | max_cache_size=iter_options.max_cache_size, |
| | | max_cache_fd=iter_options.max_cache_fd, |
| | | dest_sample_rate=args.frontend_conf["fs"], |
| | | dest_sample_rate=dest_sample_rate, |
| | | ) |
| | | cls.check_task_requirements( |
| | | dataset, args.allow_variable_data_keys, train=iter_options.train |
| | |
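The fallback in the hunk above makes `dest_sample_rate` default to 16000 when `args` has no `frontend_conf`, when it is `None`, or when it lacks an `fs` key, instead of indexing `args.frontend_conf["fs"]` unconditionally. A compact sketch of that logic, assuming `args` is an `argparse.Namespace`; `resolve_dest_sample_rate` is a hypothetical helper name:

```python
from argparse import Namespace
from typing import Any


def resolve_dest_sample_rate(args: Namespace, default: int = 16000) -> Any:
    """Use frontend_conf["fs"] when available, else fall back to the default."""
    # A missing attribute is treated the same as frontend_conf=None.
    frontend_conf = getattr(args, "frontend_conf", None)
    if frontend_conf is not None and "fs" in frontend_conf:
        return frontend_conf["fs"]
    return default


print(resolve_dest_sample_rate(Namespace(frontend_conf={"fs": 8000})))  # 8000
print(resolve_dest_sample_rate(Namespace(frontend_conf=None)))          # 16000
print(resolve_dest_sample_rate(Namespace()))                            # 16000
```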
| | | default="13_15", |
| | | help="The range of noise decibel level.", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | help="The batch interval for saving the model.", |
| | | ) |
| | | |
| | | for class_choices in cls.class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.lm.abs_model import AbsLM |
| | | from funasr.lm.espnet_model import ESPnetLanguageModel |
| | | from funasr.lm.abs_model import LanguageModel |
| | | from funasr.lm.seq_rnn_lm import SequentialRNNLM |
| | | from funasr.lm.transformer_lm import TransformerLM |
| | | from funasr.tasks.abs_task import AbsTask |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetLanguageModel), |
| | | default=get_default_kwargs(LanguageModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel: |
| | | def build_model(cls, args: argparse.Namespace) -> LanguageModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | |
| | | # 2. Build ESPnetModel |
| | | # Assume the last-id is sos_and_eos |
| | | model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) |
| | | |
| | | # 3. Initialize |
| | | if args.init is not None: |
| | |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.punctuation.espnet_model import ESPnetPunctuationModel |
| | | from funasr.punctuation.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.train.abs_model import AbsPunctuation |
| | | from funasr.train.abs_model import PunctuationModel |
| | | from funasr.models.target_delay_transformer import TargetDelayTransformer |
| | | from funasr.models.vad_realtime_transformer import VadRealtimeTransformer |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | group.add_argument( |
| | | "--model_conf", |
| | | action=NestedDictAction, |
| | | default=get_default_kwargs(ESPnetPunctuationModel), |
| | | default=get_default_kwargs(PunctuationModel), |
| | | help="The keyword arguments for model class.", |
| | | ) |
| | | |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel: |
| | | def build_model(cls, args: argparse.Namespace) -> PunctuationModel: |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | |
| | | # Assume the last-id is sos_and_eos |
| | | if "punc_weight" in args.model_conf: |
| | | args.model_conf.pop("punc_weight") |
| | | model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf) |
| | | |
| | | # FIXME(kamo): Should be done in model? |
| | | # 3. Initialize |
| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | """ |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | |
| | | s3prl=S3prlFrontend, |
| | | fused=FusedFrontends, |
| | | wav_frontend=WavFrontend, |
| | | wav_frontend_online=WavFrontendOnline, |
| | | ), |
| | | type_check=AbsFrontend, |
| | | default="default", |
| | |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("e2evad") |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf) |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | # Give features from data-loader |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend) |
| | | |
| | | return model |
| | | |
| | |
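The VAD `build_model` hunk above picks the frontend class by name, passes `cmvn_file` only to `wav_frontend`, and skips the frontend entirely when `input_size` is given, because features then come from the data loader. The sketch below reproduces that control flow with a plain dict registry; `StubWavFrontend`, `StubDefaultFrontend`, `FRONTEND_CHOICES`, and `build_frontend` are hypothetical stand-ins, not FunASR's `ClassChoices` or frontend classes:

```python
from argparse import Namespace
from typing import Dict, Optional, Tuple, Type


class StubWavFrontend:
    """Hypothetical frontend that accepts a CMVN statistics file."""

    def __init__(self, cmvn_file: Optional[str] = None, n_mels: int = 80):
        self.cmvn_file = cmvn_file
        self.n_mels = n_mels

    def output_size(self) -> int:
        return self.n_mels


class StubDefaultFrontend:
    """Hypothetical frontend with no CMVN argument."""

    def __init__(self, n_mels: int = 80):
        self.n_mels = n_mels

    def output_size(self) -> int:
        return self.n_mels


FRONTEND_CHOICES: Dict[str, Type] = {
    "wav_frontend": StubWavFrontend,
    "default": StubDefaultFrontend,
}


def build_frontend(args: Namespace) -> Tuple[Optional[object], int]:
    """Return (frontend, input_size) following the branch in the hunk."""
    if args.input_size is None:
        # Extract features in the model.
        frontend_class = FRONTEND_CHOICES[args.frontend]
        if args.frontend == "wav_frontend":
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
        else:
            frontend = frontend_class(**args.frontend_conf)
        return frontend, frontend.output_size()
    # Features come from the data loader, so no frontend is built.
    args.frontend = None
    args.frontend_conf = {}
    return None, args.input_size


# "am.mvn" is a made-up path for illustration.
args = Namespace(input_size=None, frontend="wav_frontend",
                 frontend_conf={"n_mels": 80}, cmvn_file="am.mvn")
frontend, input_size = build_frontend(args)
print(type(frontend).__name__, frontend.cmvn_file, input_size)  # StubWavFrontend am.mvn 80
```

Carrying `cmvn_file` on `args` lets `build_model_from_file` supply a CMVN path at load time, which is what the `args["cmvn_file"] = cmvn_file` line in the following hunk does.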
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | cmvn_file: Union[Path, str] = None, |
| | | ): |
| | | """Build model from the files. |
| | | |
| | |
| | | |
| | | with config_file.open("r", encoding="utf-8") as f: |
| | | args = yaml.safe_load(f) |
| | | #if cmvn_file is not None: |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | model.to(device) |
| File was renamed from funasr/punctuation/espnet_model.py |
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | |
| | | |
| | | from typing import Dict |
| | | from typing import Optional |
| | | from typing import Tuple |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.punctuation.abs_model import AbsPunctuation |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
| | | from funasr.modules.scorers.scorer_interface import BatchScorerInterface |
| | | |
| | | class ESPnetPunctuationModel(AbsESPnetModel): |
| | | |
| | | class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC): |
| | | """The abstract class |
| | | |
| | | To share the loss calculation among different models, |
| | | we use the delegate pattern here: |
| | | An instance of this class should be passed to "PunctuationModel" |
| | | |
| | | This "model" is one of mediator objects for "Task" class. |
| | | |
| | | """ |
| | | |
| | | @abstractmethod |
| | | def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | raise NotImplementedError |
| | | |
| | | @abstractmethod |
| | | def with_vad(self) -> bool: |
| | | raise NotImplementedError |
| | | |
| | | |
| | | class PunctuationModel(AbsESPnetModel): |
| | | |
| | | def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | |
| | | self.punc_weight = torch.Tensor(punc_weight) |
| | | self.sos = 1 |
| | | self.eos = 2 |
| | | |
| | | |
| | | # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. |
| | | self.ignore_id = ignore_id |
| | | #if self.punc_model.with_vad(): |
| | | # if self.punc_model.with_vad(): |
| | | # print("This is a vad punctuation model.") |
| | | |
| | | |
| | | def nll( |
| | | self, |
| | | text: torch.Tensor, |
| | |
| | | else: |
| | | text = text[:, :max_length] |
| | | punc = punc[:, :max_length] |
| | | |
| | | |
| | | if self.punc_model.with_vad(): |
| | | # Should be VadRealtimeTransformer |
| | | assert vad_indexes is not None |
| | |
| | | else: |
| | | # Should be TargetDelayTransformer, |
| | | y, _ = self.punc_model(text, text_lengths) |
| | | |
| | | |
| | | # Calc negative log likelihood |
| | | # nll: (BxL,) |
| | | if self.training == False: |
| | |
| | | return nll, text_lengths |
| | | else: |
| | | self.punc_weight = self.punc_weight.to(punc.device) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id) |
| | | nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", |
| | | ignore_index=self.ignore_id) |
| | | # nll: (BxL,) -> (BxL,) |
| | | if max_length is None: |
| | | nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0) |
| | |
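In training mode, `PunctuationModel.nll` calls `F.cross_entropy(..., self.punc_weight, reduction="none", ignore_index=self.ignore_id)`, which yields one loss per position equal to `-weight[t] * log_softmax(y)[t]`, and `0` where the target equals `ignore_index`. When `max_length` is `None`, padded positions are then zeroed with `make_pad_mask` before the `(BxL,)` vector is reshaped to `(B, L)`. A dependency-free sketch of those two steps, assuming plain Python lists in place of tensors (all function names are hypothetical):

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def weighted_nll(logits: Sequence[Sequence[float]], targets: Sequence[int],
                 weight: Sequence[float], ignore_index: int = 0) -> List[float]:
    """Per-position loss, like cross_entropy(weight=..., reduction="none")."""
    losses = []
    for row, t in zip(logits, targets):
        if t == ignore_index:
            losses.append(0.0)  # ignored targets contribute nothing
        else:
            losses.append(-weight[t] * log_softmax(row)[t])
    return losses


def mask_padding(flat: Sequence[float], lengths: Sequence[int], max_len: int) -> List[List[float]]:
    """Reshape (B*L,) -> (B, L) and zero positions at or past each length."""
    rows = []
    for b, length in enumerate(lengths):
        row = flat[b * max_len:(b + 1) * max_len]
        rows.append([v if i < length else 0.0 for i, v in enumerate(row)])
    return rows


losses = weighted_nll([[0.0, 0.0, 0.0]] * 3, targets=[1, 0, 2], weight=[1.0, 2.0, 3.0])
# Uniform logits give log_softmax = -ln 3, so losses are [2 ln 3, 0.0, 3 ln 3].
print(mask_padding([1.0, 2.0, 3.0, 4.0], lengths=[2, 1], max_len=2))  # [[1.0, 2.0], [3.0, 0.0]]
```

`forward` then averages with `loss = nll.sum() / ntokens`, where `ntokens` is the summed sequence lengths, so class weights scale each token's contribution without being renormalized.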
| | | # nll: (BxL,) -> (B, L) |
| | | nll = nll.view(batch_size, -1) |
| | | return nll, text_lengths |
| | | |
| | | |
| | | def batchify_nll(self, |
| | | text: torch.Tensor, |
| | | punc: torch.Tensor, |
| | |
| | | nlls = [] |
| | | x_lengths = [] |
| | | max_length = text_lengths.max() |
| | | |
| | | |
| | | start_idx = 0 |
| | | while True: |
| | | end_idx = min(start_idx + batch_size, total_num) |
| | |
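`batchify_nll` bounds memory by running `nll` on slices of at most `batch_size` utterances and concatenating the results, then asserting that every utterance was covered. The slicing loop visible in the hunk can be sketched generically over lists; `batchify` and `square_chunk` are hypothetical, and the real method operates on tensors:

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batchify(items: Sequence[T], batch_size: int,
             fn: Callable[[Sequence[T]], List[R]]) -> List[R]:
    """Apply fn to consecutive slices of at most batch_size items and concatenate."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total_num = len(items)
    results: List[R] = []
    start_idx = 0
    while True:
        end_idx = min(start_idx + batch_size, total_num)
        results.extend(fn(items[start_idx:end_idx]))
        start_idx = end_idx
        if start_idx == total_num:
            break
    # Same sanity check as the hunk: every input produced exactly one output.
    assert len(results) == total_num
    return results


sizes: List[int] = []


def square_chunk(xs: Sequence[int]) -> List[int]:
    sizes.append(len(xs))  # record slice sizes to show the chunking
    return [x * x for x in xs]


print(batchify(list(range(7)), 3, square_chunk), sizes)  # [0, 1, 4, 9, 16, 25, 36] [3, 3, 1]
```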
| | | assert nll.size(0) == total_num |
| | | assert x_lengths.size(0) == total_num |
| | | return nll, x_lengths |
| | | |
| | | |
| | | def forward( |
| | | self, |
| | | text: torch.Tensor, |
| | |
| | | ntokens = y_lengths.sum() |
| | | loss = nll.sum() / ntokens |
| | | stats = dict(loss=loss.detach()) |
| | | |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | |
| | | def collect_feats(self, text: torch.Tensor, punc: torch.Tensor, |
| | | text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]: |
| | | return {} |
| | | |
| | | |
| | | def inference(self, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | |
| | | if out_item['wrong'] > 0: |
| | | rst['wrong_sentences'] += 1 |
| | | cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') |
| | | cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n') |
| | | cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n') |
| | | cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n') |
| | | cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n') |
| | | |
| | | if rst['Wrd'] > 0: |
| | | rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) |
| | |
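The scoring hunk counts an utterance as a wrong sentence when it has any word errors and reports `Err = wrong_words * 100 / Wrd`, rounded to two decimals. A self-contained sketch, assuming `wrong_words` is the Levenshtein distance (substitutions + deletions + insertions) between reference and hypothesis tokens; the function names are hypothetical and the script's own alignment code is not shown in this hunk:

```python
from typing import Sequence, Tuple


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Minimum substitutions + deletions + insertions turning ref into hyp."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution or match
        prev = cur
    return prev[-1]


def error_rate(pairs: Sequence[Tuple[Sequence[str], Sequence[str]]]) -> Tuple[float, int]:
    """Return (Err in percent rounded to 2 places, number of wrong sentences)."""
    wrong_words = ref_words = wrong_sentences = 0
    for ref, hyp in pairs:
        errors = edit_distance(ref, hyp)
        wrong_words += errors
        ref_words += len(ref)
        if errors > 0:
            wrong_sentences += 1
    err = round(wrong_words * 100 / ref_words, 2) if ref_words > 0 else 0.0
    return err, wrong_sentences


print(error_rate([(list("每一天"), list("每天")), (list("abc"), list("abc"))]))  # (16.67, 1)
```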
| | | name="funasr", |
| | | version=version, |
| | | url="https://github.com/alibaba-damo-academy/FunASR.git", |
| | | author="Speech Lab, Alibaba Group, China", |
| | | author="Speech Lab of DAMO Academy, Alibaba Group", |
| | | author_email="funasr@list.alibaba-inc.com", |
| | | description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit", |
| | | long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(), |
| | |
| | | rec_result = inference_pipeline( |
| | | audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav') |
| | | logger.info("asr inference result: {0}".format(rec_result)) |
| | | assert rec_result["text"] == "每一天都要快乐喔" |
| | | |
| | | def test_paraformer(self): |
| | | inference_pipeline = pipeline( |
| | |
| | | rec_result = inference_pipeline( |
| | | audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav') |
| | | logger.info("asr inference result: {0}".format(rec_result)) |
| | | assert rec_result["text"] == "每一天都要快乐喔" |
| | | |
| | | |
| | | class TestMfccaInferencePipelines(unittest.TestCase): |
| | |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.punctuation, |
| | | model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727', |
| | | model_revision="v1.0.0", |
| | | ) |
| | | inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益" |
| | | vads = inputs.split("|") |
| | | cache_out = [] |
| | | rec_result_all = "outputs:" |
| | | param_dict = {"cache": []} |
| | | for vad in vads: |
| | | rec_result = inference_pipeline(text_in=vad, cache=cache_out) |
| | | cache_out = rec_result['cache'] |
| | | rec_result_all += rec_result['text'] |
| | | rec_result = inference_pipeline(text_in=vad, param_dict=param_dict) |
| | | rec_result_all += rec_result["text"] |
| | | logger.info("punctuation inference result: {0}".format(rec_result_all)) |
| | | |
| | | |