From 19f416f5fc916d985b91733a0fd6271517b9fe0f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 二月 2023 19:33:36 +0800
Subject: [PATCH] Merge pull request #96 from alibaba-damo-academy/main
---
egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py | 132
funasr/bin/lm_inference_launch.py | 130
setup.py | 38
egs/alimeeting/diarization/sond/infer_alimeeting_test.py | 24
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md | 25
funasr/export/export_model.py | 63
funasr/models/encoder/opennmt_encoders/__init__.py | 0
funasr/bin/asr_inference_paraformer.py | 61
funasr/bin/asr_inference_uniasr_vad.py | 2
docs/installation.md | 2
funasr/models/encoder/opennmt_encoders/conv_encoder.py | 277 +
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py | 88
funasr/bin/asr_inference_paraformer_vad_punc.py | 38
funasr/models/encoder/resnet34_encoder.py | 477 ++
funasr/bin/sond_inference.py | 544 ++
funasr/bin/sv_inference_launch.py | 4
funasr/bin/asr_inference_paraformer_vad.py | 19
funasr/models/pooling/statistic_pooling.py | 59
funasr/tasks/diar.py | 585 ++
funasr/tasks/lm.py | 2
funasr/tasks/abs_task.py | 47
funasr/models/encoder/opennmt_encoders/ci_scorers.py | 38
funasr/models/encoder/opennmt_encoders/self_attention_encoder.py | 480 ++
funasr/utils/postprocess_utils.py | 6
egs/alimeeting/diarization/sond/config_fbank.yaml | 2728 ++++++++++++
README.md | 16
funasr/utils/misc.py | 48
funasr/bin/asr_inference_uniasr.py | 2
funasr/bin/asr_inference.py | 2
funasr/utils/timestamp_tools.py | 57
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md | 53
funasr/bin/asr_inference_paraformer_timestamp.py | 2
funasr/models/e2e_diar_sond.py | 402 +
funasr/version.txt | 2
docs_cn/installation.md | 5
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py | 54
funasr/modules/multi_layer_conv.py | 52
funasr/export/models/__init__.py | 84
funasr/modules/attention.py | 106
funasr/bin/lm_calc_perplexity.py | 3
funasr/datasets/preprocessor.py | 73
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py | 87
funasr/bin/lm_inference.py | 406 +
egs/alimeeting/diarization/sond/unit_test.py | 97
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py | 52
funasr/models/encoder/opennmt_encoders/fsmn_encoder.py | 335 +
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md | 53
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer_after_finetune.py | 54
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md | 23
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/finetune.py | 39
funasr/bin/lm_train.py | 50
funasr/models/e2e_asr_paraformer.py | 496 ++
egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py | 37
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py | 39
funasr/models/decoder/contextual_decoder.py | 776 +++
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md | 75
funasr/utils/job_runner.py | 103
funasr/export/README.md | 15
funasr/bin/diar_inference_launch.py | 179
funasr/tasks/asr.py | 8
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/README.md | 53
egs/alimeeting/diarization/sond/run.sh | 48
funasr/models/frontend/wav_frontend.py | 12
funasr/lm/espnet_model.py | 4
egs/alimeeting/diarization/sond/path.sh | 5
funasr/models/predictor/cif.py | 5
funasr/bin/sv_inference.py | 29
funasr/bin/tokenize_text.py | 283 +
egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py | 88
egs/alimeeting/diarization/sond/README.md | 6
egs/alimeeting/diarization/sond/config.yaml | 2740 ++++++++++++
egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py | 17
72 files changed, 12,718 insertions(+), 326 deletions(-)
diff --git a/README.md b/README.md
index 7b16f58..c759b79 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,17 @@
<strong>FunASR</strong> hopes to build a bridge between academic research and industrial applications on speech recognition. By supporting the training & finetuning of the industrial-grade speech recognition model released on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition), researchers and developers can conduct research and production of speech recognition models more conveniently, and promote the development of speech recognition ecology. ASR for Fun锛�
-## Release Notes:
+[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
+| [**Highlights**](#highlights)
+| [**Installation**](#installation)
+| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/index.html)
+| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
+| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
+| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
+| [**Model Zoo**](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+| [**Contact**](#contact)
+
+## What's new:
### 2023.1.16, funasr-0.1.6
- We release a new version model [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary), which integrate the [VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) model, [ASR](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary),
[Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) model and timestamp together. The model could take in several hours long inputs.
@@ -16,7 +26,7 @@
- We improve the pipeline of modelscope to speedup the inference, by integrating the process of build model into build pipeline.
- Various new types of audio input types are now supported by modelscope inference pipeline, including wav.scp, wav format, audio bytes, wave samples...
-## Key Features
+## Highlights
- Many types of typical models are supported, e.g., [Tranformer](https://arxiv.org/abs/1706.03762), [Conformer](https://arxiv.org/abs/2005.08100), [Paraformer](https://arxiv.org/abs/2206.08317).
- We have released large number of academic and industrial pretrained models on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition)
- The pretrained model [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) obtains the best performance on many tasks in [SpeechIO leaderboard](https://github.com/SpeechColab/Leaderboard)
@@ -75,4 +85,4 @@
booktitle={INTERSPEECH},
year={2022}
}
-```
+```
\ No newline at end of file
diff --git a/docs/installation.md b/docs/installation.md
index 1eb813b..61d06b5 100755
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -26,7 +26,7 @@
- Install ModelScope
``` sh
-pip install "modelscope[audio]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
- Install other packages
diff --git a/docs_cn/installation.md b/docs_cn/installation.md
index fc74780..a31bc01 100755
--- a/docs_cn/installation.md
+++ b/docs_cn/installation.md
@@ -5,6 +5,7 @@
``` sh
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
sh Miniconda3-latest-Linux-x86_64.sh
+source ~/.bashrc
conda create -n funasr python=3.7
conda activate funasr
```
@@ -12,7 +13,7 @@
- 瀹夎Pytorch (鐗堟湰 >= 1.7.0):
```sh
-pip install torch torchvision torchaudio
+pip install torch torchaudio
```
鍏充簬鏇村鐨勭増鏈�, 璇峰弬鐓� [https://pytorch.org/get-started/locally](https://pytorch.org/get-started/locally)
@@ -26,7 +27,7 @@
瀹夎鎴栨洿鏂癕odelScope
``` sh
-pip install "modelscope[audio]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
- 涓嬭浇FunASR浠撳簱锛屽苟瀹夎鍓╀綑鎵�闇�渚濊禆
diff --git a/egs/alimeeting/diarization/sond/README.md b/egs/alimeeting/diarization/sond/README.md
new file mode 100644
index 0000000..8bef142
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/README.md
@@ -0,0 +1,6 @@
+# Results
+You will get a DER about 4.21%, which is reported in [1], Table 6, line "SOND Oracle Profile".
+
+# Reference
+[1] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, Zhihao Du, Shiliang Zhang,
+Siqi Zheng, Zhijie Yan. EMNLP 2022.
\ No newline at end of file
diff --git a/egs/alimeeting/diarization/sond/config.yaml b/egs/alimeeting/diarization/sond/config.yaml
new file mode 100644
index 0000000..072c171
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/config.yaml
@@ -0,0 +1,2740 @@
+config: finetune.yaml
+print_config: false
+log_level: INFO
+dry_run: false
+iterator_type: sequence
+output_dir: exp/sond
+ngpu: 1
+seed: 0
+num_workers: 16
+num_att_plot: 0
+dist_backend: nccl
+dist_init_method: env://
+dist_world_size: null
+dist_rank: null
+local_rank: 0
+dist_master_addr: null
+dist_master_port: null
+dist_launcher: null
+multiprocessing_distributed: true
+distributed: false
+unused_parameters: true
+sharded_ddp: false
+ddp_backend: pytorch_ddp
+cudnn_enabled: true
+cudnn_benchmark: false
+cudnn_deterministic: true
+collect_stats: false
+write_collected_feats: false
+max_epoch: 50
+patience: null
+val_scheduler_criterion:
+- valid
+- acc
+early_stopping_criterion:
+- valid
+- loss
+- min
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+nbest_averaging_interval: 0
+grad_clip: 5
+grad_clip_type: 2.0
+grad_noise: false
+accum_grad: 1
+no_forward_run: false
+resume: true
+train_dtype: float32
+use_amp: false
+log_interval: 50
+use_matplotlib: false
+use_tensorboard: true
+use_wandb: false
+wandb_project: null
+wandb_id: null
+wandb_entity: null
+wandb_name: null
+wandb_model_log_interval: -1
+use_pai: true
+detect_anomaly: false
+pretrain_path: null
+init_param: []
+ignore_init_mismatch: false
+freeze_param: []
+num_iters_per_epoch: null
+batch_size: 20
+valid_batch_size: null
+batch_bins: 10000
+valid_batch_bins: null
+train_shape_file:
+- /data/volume1/youyan/aishell/ark/train/speech_shape.1
+- /data/volume1/youyan/aishell/ark/train/text_shape.1
+valid_shape_file:
+- /data/volume1/youyan/aishell/ark/dev/speech_shape.1
+- /data/volume1/youyan/aishell/ark/dev/text_shape.1
+batch_type: length
+valid_batch_type: null
+fold_length:
+- 512
+- 150
+sort_in_batch: descending
+sort_batch: descending
+multiple_iterator: false
+chunk_length: 500
+chunk_shift_ratio: 0.5
+num_cache_chunks: 1024
+train_data_path_and_name_and_type:
+- - /data/volume1/youyan/aishell/ark/train/data.scp
+ - speech
+ - kaldi_ark
+- - /data/volume1/youyan/aishell/ark/train/data.text.1
+ - text
+ - text
+valid_data_path_and_name_and_type:
+- - /data/volume1/youyan/aishell/ark/dev/data.scp
+ - speech
+ - kaldi_ark
+- - /data/volume1/youyan/aishell/ark/dev/data.text.1
+ - text
+ - text
+allow_variable_data_keys: false
+max_cache_size: 0.0
+max_cache_fd: 32
+valid_max_cache_size: null
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+token_list:
+- '0'
+- '1'
+- '2'
+- '3'
+- '4'
+- '5'
+- '6'
+- '7'
+- '8'
+- '9'
+- '10'
+- '11'
+- '12'
+- '13'
+- '14'
+- '15'
+- '16'
+- '17'
+- '18'
+- '19'
+- '20'
+- '21'
+- '22'
+- '23'
+- '24'
+- '25'
+- '26'
+- '27'
+- '28'
+- '29'
+- '30'
+- '32'
+- '33'
+- '34'
+- '35'
+- '36'
+- '37'
+- '38'
+- '39'
+- '40'
+- '41'
+- '42'
+- '43'
+- '44'
+- '45'
+- '46'
+- '48'
+- '49'
+- '50'
+- '51'
+- '52'
+- '53'
+- '54'
+- '56'
+- '57'
+- '58'
+- '60'
+- '64'
+- '65'
+- '66'
+- '67'
+- '68'
+- '69'
+- '70'
+- '71'
+- '72'
+- '73'
+- '74'
+- '75'
+- '76'
+- '77'
+- '78'
+- '80'
+- '81'
+- '82'
+- '83'
+- '84'
+- '85'
+- '86'
+- '88'
+- '89'
+- '90'
+- '92'
+- '96'
+- '97'
+- '98'
+- '99'
+- '100'
+- '101'
+- '102'
+- '104'
+- '105'
+- '106'
+- '108'
+- '112'
+- '113'
+- '114'
+- '116'
+- '120'
+- '128'
+- '129'
+- '130'
+- '131'
+- '132'
+- '133'
+- '134'
+- '135'
+- '136'
+- '137'
+- '138'
+- '139'
+- '140'
+- '141'
+- '142'
+- '144'
+- '145'
+- '146'
+- '147'
+- '148'
+- '149'
+- '150'
+- '152'
+- '153'
+- '154'
+- '156'
+- '160'
+- '161'
+- '162'
+- '163'
+- '164'
+- '165'
+- '166'
+- '168'
+- '169'
+- '170'
+- '172'
+- '176'
+- '177'
+- '178'
+- '180'
+- '184'
+- '192'
+- '193'
+- '194'
+- '195'
+- '196'
+- '197'
+- '198'
+- '200'
+- '201'
+- '202'
+- '204'
+- '208'
+- '209'
+- '210'
+- '212'
+- '216'
+- '224'
+- '225'
+- '226'
+- '228'
+- '232'
+- '240'
+- '256'
+- '257'
+- '258'
+- '259'
+- '260'
+- '261'
+- '262'
+- '263'
+- '264'
+- '265'
+- '266'
+- '267'
+- '268'
+- '269'
+- '270'
+- '272'
+- '273'
+- '274'
+- '275'
+- '276'
+- '277'
+- '278'
+- '280'
+- '281'
+- '282'
+- '284'
+- '288'
+- '289'
+- '290'
+- '291'
+- '292'
+- '293'
+- '294'
+- '296'
+- '297'
+- '298'
+- '300'
+- '304'
+- '305'
+- '306'
+- '308'
+- '312'
+- '320'
+- '321'
+- '322'
+- '323'
+- '324'
+- '325'
+- '326'
+- '328'
+- '329'
+- '330'
+- '332'
+- '336'
+- '337'
+- '338'
+- '340'
+- '344'
+- '352'
+- '353'
+- '354'
+- '356'
+- '360'
+- '368'
+- '384'
+- '385'
+- '386'
+- '387'
+- '388'
+- '389'
+- '390'
+- '392'
+- '393'
+- '394'
+- '396'
+- '400'
+- '401'
+- '402'
+- '404'
+- '408'
+- '416'
+- '417'
+- '418'
+- '420'
+- '424'
+- '432'
+- '448'
+- '449'
+- '450'
+- '452'
+- '456'
+- '464'
+- '480'
+- '512'
+- '513'
+- '514'
+- '515'
+- '516'
+- '517'
+- '518'
+- '519'
+- '520'
+- '521'
+- '522'
+- '523'
+- '524'
+- '525'
+- '526'
+- '528'
+- '529'
+- '530'
+- '531'
+- '532'
+- '533'
+- '534'
+- '536'
+- '537'
+- '538'
+- '540'
+- '544'
+- '545'
+- '546'
+- '547'
+- '548'
+- '549'
+- '550'
+- '552'
+- '553'
+- '554'
+- '556'
+- '560'
+- '561'
+- '562'
+- '564'
+- '568'
+- '576'
+- '577'
+- '578'
+- '579'
+- '580'
+- '581'
+- '582'
+- '584'
+- '585'
+- '586'
+- '588'
+- '592'
+- '593'
+- '594'
+- '596'
+- '600'
+- '608'
+- '609'
+- '610'
+- '612'
+- '616'
+- '624'
+- '640'
+- '641'
+- '642'
+- '643'
+- '644'
+- '645'
+- '646'
+- '648'
+- '649'
+- '650'
+- '652'
+- '656'
+- '657'
+- '658'
+- '660'
+- '664'
+- '672'
+- '673'
+- '674'
+- '676'
+- '680'
+- '688'
+- '704'
+- '705'
+- '706'
+- '708'
+- '712'
+- '720'
+- '736'
+- '768'
+- '769'
+- '770'
+- '771'
+- '772'
+- '773'
+- '774'
+- '776'
+- '777'
+- '778'
+- '780'
+- '784'
+- '785'
+- '786'
+- '788'
+- '792'
+- '800'
+- '801'
+- '802'
+- '804'
+- '808'
+- '816'
+- '832'
+- '833'
+- '834'
+- '836'
+- '840'
+- '848'
+- '864'
+- '896'
+- '897'
+- '898'
+- '900'
+- '904'
+- '912'
+- '928'
+- '960'
+- '1024'
+- '1025'
+- '1026'
+- '1027'
+- '1028'
+- '1029'
+- '1030'
+- '1031'
+- '1032'
+- '1033'
+- '1034'
+- '1035'
+- '1036'
+- '1037'
+- '1038'
+- '1040'
+- '1041'
+- '1042'
+- '1043'
+- '1044'
+- '1045'
+- '1046'
+- '1048'
+- '1049'
+- '1050'
+- '1052'
+- '1056'
+- '1057'
+- '1058'
+- '1059'
+- '1060'
+- '1061'
+- '1062'
+- '1064'
+- '1065'
+- '1066'
+- '1068'
+- '1072'
+- '1073'
+- '1074'
+- '1076'
+- '1080'
+- '1088'
+- '1089'
+- '1090'
+- '1091'
+- '1092'
+- '1093'
+- '1094'
+- '1096'
+- '1097'
+- '1098'
+- '1100'
+- '1104'
+- '1105'
+- '1106'
+- '1108'
+- '1112'
+- '1120'
+- '1121'
+- '1122'
+- '1124'
+- '1128'
+- '1136'
+- '1152'
+- '1153'
+- '1154'
+- '1155'
+- '1156'
+- '1157'
+- '1158'
+- '1160'
+- '1161'
+- '1162'
+- '1164'
+- '1168'
+- '1169'
+- '1170'
+- '1172'
+- '1176'
+- '1184'
+- '1185'
+- '1186'
+- '1188'
+- '1192'
+- '1200'
+- '1216'
+- '1217'
+- '1218'
+- '1220'
+- '1224'
+- '1232'
+- '1248'
+- '1280'
+- '1281'
+- '1282'
+- '1283'
+- '1284'
+- '1285'
+- '1286'
+- '1288'
+- '1289'
+- '1290'
+- '1292'
+- '1296'
+- '1297'
+- '1298'
+- '1300'
+- '1304'
+- '1312'
+- '1313'
+- '1314'
+- '1316'
+- '1320'
+- '1328'
+- '1344'
+- '1345'
+- '1346'
+- '1348'
+- '1352'
+- '1360'
+- '1376'
+- '1408'
+- '1409'
+- '1410'
+- '1412'
+- '1416'
+- '1424'
+- '1440'
+- '1472'
+- '1536'
+- '1537'
+- '1538'
+- '1539'
+- '1540'
+- '1541'
+- '1542'
+- '1544'
+- '1545'
+- '1546'
+- '1548'
+- '1552'
+- '1553'
+- '1554'
+- '1556'
+- '1560'
+- '1568'
+- '1569'
+- '1570'
+- '1572'
+- '1576'
+- '1584'
+- '1600'
+- '1601'
+- '1602'
+- '1604'
+- '1608'
+- '1616'
+- '1632'
+- '1664'
+- '1665'
+- '1666'
+- '1668'
+- '1672'
+- '1680'
+- '1696'
+- '1728'
+- '1792'
+- '1793'
+- '1794'
+- '1796'
+- '1800'
+- '1808'
+- '1824'
+- '1856'
+- '1920'
+- '2048'
+- '2049'
+- '2050'
+- '2051'
+- '2052'
+- '2053'
+- '2054'
+- '2055'
+- '2056'
+- '2057'
+- '2058'
+- '2059'
+- '2060'
+- '2061'
+- '2062'
+- '2064'
+- '2065'
+- '2066'
+- '2067'
+- '2068'
+- '2069'
+- '2070'
+- '2072'
+- '2073'
+- '2074'
+- '2076'
+- '2080'
+- '2081'
+- '2082'
+- '2083'
+- '2084'
+- '2085'
+- '2086'
+- '2088'
+- '2089'
+- '2090'
+- '2092'
+- '2096'
+- '2097'
+- '2098'
+- '2100'
+- '2104'
+- '2112'
+- '2113'
+- '2114'
+- '2115'
+- '2116'
+- '2117'
+- '2118'
+- '2120'
+- '2121'
+- '2122'
+- '2124'
+- '2128'
+- '2129'
+- '2130'
+- '2132'
+- '2136'
+- '2144'
+- '2145'
+- '2146'
+- '2148'
+- '2152'
+- '2160'
+- '2176'
+- '2177'
+- '2178'
+- '2179'
+- '2180'
+- '2181'
+- '2182'
+- '2184'
+- '2185'
+- '2186'
+- '2188'
+- '2192'
+- '2193'
+- '2194'
+- '2196'
+- '2200'
+- '2208'
+- '2209'
+- '2210'
+- '2212'
+- '2216'
+- '2224'
+- '2240'
+- '2241'
+- '2242'
+- '2244'
+- '2248'
+- '2256'
+- '2272'
+- '2304'
+- '2305'
+- '2306'
+- '2307'
+- '2308'
+- '2309'
+- '2310'
+- '2312'
+- '2313'
+- '2314'
+- '2316'
+- '2320'
+- '2321'
+- '2322'
+- '2324'
+- '2328'
+- '2336'
+- '2337'
+- '2338'
+- '2340'
+- '2344'
+- '2352'
+- '2368'
+- '2369'
+- '2370'
+- '2372'
+- '2376'
+- '2384'
+- '2400'
+- '2432'
+- '2433'
+- '2434'
+- '2436'
+- '2440'
+- '2448'
+- '2464'
+- '2496'
+- '2560'
+- '2561'
+- '2562'
+- '2563'
+- '2564'
+- '2565'
+- '2566'
+- '2568'
+- '2569'
+- '2570'
+- '2572'
+- '2576'
+- '2577'
+- '2578'
+- '2580'
+- '2584'
+- '2592'
+- '2593'
+- '2594'
+- '2596'
+- '2600'
+- '2608'
+- '2624'
+- '2625'
+- '2626'
+- '2628'
+- '2632'
+- '2640'
+- '2656'
+- '2688'
+- '2689'
+- '2690'
+- '2692'
+- '2696'
+- '2704'
+- '2720'
+- '2752'
+- '2816'
+- '2817'
+- '2818'
+- '2820'
+- '2824'
+- '2832'
+- '2848'
+- '2880'
+- '2944'
+- '3072'
+- '3073'
+- '3074'
+- '3075'
+- '3076'
+- '3077'
+- '3078'
+- '3080'
+- '3081'
+- '3082'
+- '3084'
+- '3088'
+- '3089'
+- '3090'
+- '3092'
+- '3096'
+- '3104'
+- '3105'
+- '3106'
+- '3108'
+- '3112'
+- '3120'
+- '3136'
+- '3137'
+- '3138'
+- '3140'
+- '3144'
+- '3152'
+- '3168'
+- '3200'
+- '3201'
+- '3202'
+- '3204'
+- '3208'
+- '3216'
+- '3232'
+- '3264'
+- '3328'
+- '3329'
+- '3330'
+- '3332'
+- '3336'
+- '3344'
+- '3360'
+- '3392'
+- '3456'
+- '3584'
+- '3585'
+- '3586'
+- '3588'
+- '3592'
+- '3600'
+- '3616'
+- '3648'
+- '3712'
+- '3840'
+- '4096'
+- '4097'
+- '4098'
+- '4099'
+- '4100'
+- '4101'
+- '4102'
+- '4103'
+- '4104'
+- '4105'
+- '4106'
+- '4107'
+- '4108'
+- '4109'
+- '4110'
+- '4112'
+- '4113'
+- '4114'
+- '4115'
+- '4116'
+- '4117'
+- '4118'
+- '4120'
+- '4121'
+- '4122'
+- '4124'
+- '4128'
+- '4129'
+- '4130'
+- '4131'
+- '4132'
+- '4133'
+- '4134'
+- '4136'
+- '4137'
+- '4138'
+- '4140'
+- '4144'
+- '4145'
+- '4146'
+- '4148'
+- '4152'
+- '4160'
+- '4161'
+- '4162'
+- '4163'
+- '4164'
+- '4165'
+- '4166'
+- '4168'
+- '4169'
+- '4170'
+- '4172'
+- '4176'
+- '4177'
+- '4178'
+- '4180'
+- '4184'
+- '4192'
+- '4193'
+- '4194'
+- '4196'
+- '4200'
+- '4208'
+- '4224'
+- '4225'
+- '4226'
+- '4227'
+- '4228'
+- '4229'
+- '4230'
+- '4232'
+- '4233'
+- '4234'
+- '4236'
+- '4240'
+- '4241'
+- '4242'
+- '4244'
+- '4248'
+- '4256'
+- '4257'
+- '4258'
+- '4260'
+- '4264'
+- '4272'
+- '4288'
+- '4289'
+- '4290'
+- '4292'
+- '4296'
+- '4304'
+- '4320'
+- '4352'
+- '4353'
+- '4354'
+- '4355'
+- '4356'
+- '4357'
+- '4358'
+- '4360'
+- '4361'
+- '4362'
+- '4364'
+- '4368'
+- '4369'
+- '4370'
+- '4372'
+- '4376'
+- '4384'
+- '4385'
+- '4386'
+- '4388'
+- '4392'
+- '4400'
+- '4416'
+- '4417'
+- '4418'
+- '4420'
+- '4424'
+- '4432'
+- '4448'
+- '4480'
+- '4481'
+- '4482'
+- '4484'
+- '4488'
+- '4496'
+- '4512'
+- '4544'
+- '4608'
+- '4609'
+- '4610'
+- '4611'
+- '4612'
+- '4613'
+- '4614'
+- '4616'
+- '4617'
+- '4618'
+- '4620'
+- '4624'
+- '4625'
+- '4626'
+- '4628'
+- '4632'
+- '4640'
+- '4641'
+- '4642'
+- '4644'
+- '4648'
+- '4656'
+- '4672'
+- '4673'
+- '4674'
+- '4676'
+- '4680'
+- '4688'
+- '4704'
+- '4736'
+- '4737'
+- '4738'
+- '4740'
+- '4744'
+- '4752'
+- '4768'
+- '4800'
+- '4864'
+- '4865'
+- '4866'
+- '4868'
+- '4872'
+- '4880'
+- '4896'
+- '4928'
+- '4992'
+- '5120'
+- '5121'
+- '5122'
+- '5123'
+- '5124'
+- '5125'
+- '5126'
+- '5128'
+- '5129'
+- '5130'
+- '5132'
+- '5136'
+- '5137'
+- '5138'
+- '5140'
+- '5144'
+- '5152'
+- '5153'
+- '5154'
+- '5156'
+- '5160'
+- '5168'
+- '5184'
+- '5185'
+- '5186'
+- '5188'
+- '5192'
+- '5200'
+- '5216'
+- '5248'
+- '5249'
+- '5250'
+- '5252'
+- '5256'
+- '5264'
+- '5280'
+- '5312'
+- '5376'
+- '5377'
+- '5378'
+- '5380'
+- '5384'
+- '5392'
+- '5408'
+- '5440'
+- '5504'
+- '5632'
+- '5633'
+- '5634'
+- '5636'
+- '5640'
+- '5648'
+- '5664'
+- '5696'
+- '5760'
+- '5888'
+- '6144'
+- '6145'
+- '6146'
+- '6147'
+- '6148'
+- '6149'
+- '6150'
+- '6152'
+- '6153'
+- '6154'
+- '6156'
+- '6160'
+- '6161'
+- '6162'
+- '6164'
+- '6168'
+- '6176'
+- '6177'
+- '6178'
+- '6180'
+- '6184'
+- '6192'
+- '6208'
+- '6209'
+- '6210'
+- '6212'
+- '6216'
+- '6224'
+- '6240'
+- '6272'
+- '6273'
+- '6274'
+- '6276'
+- '6280'
+- '6288'
+- '6304'
+- '6336'
+- '6400'
+- '6401'
+- '6402'
+- '6404'
+- '6408'
+- '6416'
+- '6432'
+- '6464'
+- '6528'
+- '6656'
+- '6657'
+- '6658'
+- '6660'
+- '6664'
+- '6672'
+- '6688'
+- '6720'
+- '6784'
+- '6912'
+- '7168'
+- '7169'
+- '7170'
+- '7172'
+- '7176'
+- '7184'
+- '7200'
+- '7232'
+- '7296'
+- '7424'
+- '7680'
+- '8192'
+- '8193'
+- '8194'
+- '8195'
+- '8196'
+- '8197'
+- '8198'
+- '8199'
+- '8200'
+- '8201'
+- '8202'
+- '8203'
+- '8204'
+- '8205'
+- '8206'
+- '8208'
+- '8209'
+- '8210'
+- '8211'
+- '8212'
+- '8213'
+- '8214'
+- '8216'
+- '8217'
+- '8218'
+- '8220'
+- '8224'
+- '8225'
+- '8226'
+- '8227'
+- '8228'
+- '8229'
+- '8230'
+- '8232'
+- '8233'
+- '8234'
+- '8236'
+- '8240'
+- '8241'
+- '8242'
+- '8244'
+- '8248'
+- '8256'
+- '8257'
+- '8258'
+- '8259'
+- '8260'
+- '8261'
+- '8262'
+- '8264'
+- '8265'
+- '8266'
+- '8268'
+- '8272'
+- '8273'
+- '8274'
+- '8276'
+- '8280'
+- '8288'
+- '8289'
+- '8290'
+- '8292'
+- '8296'
+- '8304'
+- '8320'
+- '8321'
+- '8322'
+- '8323'
+- '8324'
+- '8325'
+- '8326'
+- '8328'
+- '8329'
+- '8330'
+- '8332'
+- '8336'
+- '8337'
+- '8338'
+- '8340'
+- '8344'
+- '8352'
+- '8353'
+- '8354'
+- '8356'
+- '8360'
+- '8368'
+- '8384'
+- '8385'
+- '8386'
+- '8388'
+- '8392'
+- '8400'
+- '8416'
+- '8448'
+- '8449'
+- '8450'
+- '8451'
+- '8452'
+- '8453'
+- '8454'
+- '8456'
+- '8457'
+- '8458'
+- '8460'
+- '8464'
+- '8465'
+- '8466'
+- '8468'
+- '8472'
+- '8480'
+- '8481'
+- '8482'
+- '8484'
+- '8488'
+- '8496'
+- '8512'
+- '8513'
+- '8514'
+- '8516'
+- '8520'
+- '8528'
+- '8544'
+- '8576'
+- '8577'
+- '8578'
+- '8580'
+- '8584'
+- '8592'
+- '8608'
+- '8640'
+- '8704'
+- '8705'
+- '8706'
+- '8707'
+- '8708'
+- '8709'
+- '8710'
+- '8712'
+- '8713'
+- '8714'
+- '8716'
+- '8720'
+- '8721'
+- '8722'
+- '8724'
+- '8728'
+- '8736'
+- '8737'
+- '8738'
+- '8740'
+- '8744'
+- '8752'
+- '8768'
+- '8769'
+- '8770'
+- '8772'
+- '8776'
+- '8784'
+- '8800'
+- '8832'
+- '8833'
+- '8834'
+- '8836'
+- '8840'
+- '8848'
+- '8864'
+- '8896'
+- '8960'
+- '8961'
+- '8962'
+- '8964'
+- '8968'
+- '8976'
+- '8992'
+- '9024'
+- '9088'
+- '9216'
+- '9217'
+- '9218'
+- '9219'
+- '9220'
+- '9221'
+- '9222'
+- '9224'
+- '9225'
+- '9226'
+- '9228'
+- '9232'
+- '9233'
+- '9234'
+- '9236'
+- '9240'
+- '9248'
+- '9249'
+- '9250'
+- '9252'
+- '9256'
+- '9264'
+- '9280'
+- '9281'
+- '9282'
+- '9284'
+- '9288'
+- '9296'
+- '9312'
+- '9344'
+- '9345'
+- '9346'
+- '9348'
+- '9352'
+- '9360'
+- '9376'
+- '9408'
+- '9472'
+- '9473'
+- '9474'
+- '9476'
+- '9480'
+- '9488'
+- '9504'
+- '9536'
+- '9600'
+- '9728'
+- '9729'
+- '9730'
+- '9732'
+- '9736'
+- '9744'
+- '9760'
+- '9792'
+- '9856'
+- '9984'
+- '10240'
+- '10241'
+- '10242'
+- '10243'
+- '10244'
+- '10245'
+- '10246'
+- '10248'
+- '10249'
+- '10250'
+- '10252'
+- '10256'
+- '10257'
+- '10258'
+- '10260'
+- '10264'
+- '10272'
+- '10273'
+- '10274'
+- '10276'
+- '10280'
+- '10288'
+- '10304'
+- '10305'
+- '10306'
+- '10308'
+- '10312'
+- '10320'
+- '10336'
+- '10368'
+- '10369'
+- '10370'
+- '10372'
+- '10376'
+- '10384'
+- '10400'
+- '10432'
+- '10496'
+- '10497'
+- '10498'
+- '10500'
+- '10504'
+- '10512'
+- '10528'
+- '10560'
+- '10624'
+- '10752'
+- '10753'
+- '10754'
+- '10756'
+- '10760'
+- '10768'
+- '10784'
+- '10816'
+- '10880'
+- '11008'
+- '11264'
+- '11265'
+- '11266'
+- '11268'
+- '11272'
+- '11280'
+- '11296'
+- '11328'
+- '11392'
+- '11520'
+- '11776'
+- '12288'
+- '12289'
+- '12290'
+- '12291'
+- '12292'
+- '12293'
+- '12294'
+- '12296'
+- '12297'
+- '12298'
+- '12300'
+- '12304'
+- '12305'
+- '12306'
+- '12308'
+- '12312'
+- '12320'
+- '12321'
+- '12322'
+- '12324'
+- '12328'
+- '12336'
+- '12352'
+- '12353'
+- '12354'
+- '12356'
+- '12360'
+- '12368'
+- '12384'
+- '12416'
+- '12417'
+- '12418'
+- '12420'
+- '12424'
+- '12432'
+- '12448'
+- '12480'
+- '12544'
+- '12545'
+- '12546'
+- '12548'
+- '12552'
+- '12560'
+- '12576'
+- '12608'
+- '12672'
+- '12800'
+- '12801'
+- '12802'
+- '12804'
+- '12808'
+- '12816'
+- '12832'
+- '12864'
+- '12928'
+- '13056'
+- '13312'
+- '13313'
+- '13314'
+- '13316'
+- '13320'
+- '13328'
+- '13344'
+- '13376'
+- '13440'
+- '13568'
+- '13824'
+- '14336'
+- '14337'
+- '14338'
+- '14340'
+- '14344'
+- '14352'
+- '14368'
+- '14400'
+- '14464'
+- '14592'
+- '14848'
+- '15360'
+- '16384'
+- '16385'
+- '16386'
+- '16387'
+- '16388'
+- '16389'
+- '16390'
+- '16391'
+- '16392'
+- '16393'
+- '16394'
+- '16395'
+- '16396'
+- '16397'
+- '16398'
+- '16400'
+- '16401'
+- '16402'
+- '16403'
+- '16404'
+- '16405'
+- '16406'
+- '16408'
+- '16409'
+- '16410'
+- '16412'
+- '16416'
+- '16417'
+- '16418'
+- '16419'
+- '16420'
+- '16421'
+- '16422'
+- '16424'
+- '16425'
+- '16426'
+- '16428'
+- '16432'
+- '16433'
+- '16434'
+- '16436'
+- '16440'
+- '16448'
+- '16449'
+- '16450'
+- '16451'
+- '16452'
+- '16453'
+- '16454'
+- '16456'
+- '16457'
+- '16458'
+- '16460'
+- '16464'
+- '16465'
+- '16466'
+- '16468'
+- '16472'
+- '16480'
+- '16481'
+- '16482'
+- '16484'
+- '16488'
+- '16496'
+- '16512'
+- '16513'
+- '16514'
+- '16515'
+- '16516'
+- '16517'
+- '16518'
+- '16520'
+- '16521'
+- '16522'
+- '16524'
+- '16528'
+- '16529'
+- '16530'
+- '16532'
+- '16536'
+- '16544'
+- '16545'
+- '16546'
+- '16548'
+- '16552'
+- '16560'
+- '16576'
+- '16577'
+- '16578'
+- '16580'
+- '16584'
+- '16592'
+- '16608'
+- '16640'
+- '16641'
+- '16642'
+- '16643'
+- '16644'
+- '16645'
+- '16646'
+- '16648'
+- '16649'
+- '16650'
+- '16652'
+- '16656'
+- '16657'
+- '16658'
+- '16660'
+- '16664'
+- '16672'
+- '16673'
+- '16674'
+- '16676'
+- '16680'
+- '16688'
+- '16704'
+- '16705'
+- '16706'
+- '16708'
+- '16712'
+- '16720'
+- '16736'
+- '16768'
+- '16769'
+- '16770'
+- '16772'
+- '16776'
+- '16784'
+- '16800'
+- '16832'
+- '16896'
+- '16897'
+- '16898'
+- '16899'
+- '16900'
+- '16901'
+- '16902'
+- '16904'
+- '16905'
+- '16906'
+- '16908'
+- '16912'
+- '16913'
+- '16914'
+- '16916'
+- '16920'
+- '16928'
+- '16929'
+- '16930'
+- '16932'
+- '16936'
+- '16944'
+- '16960'
+- '16961'
+- '16962'
+- '16964'
+- '16968'
+- '16976'
+- '16992'
+- '17024'
+- '17025'
+- '17026'
+- '17028'
+- '17032'
+- '17040'
+- '17056'
+- '17088'
+- '17152'
+- '17153'
+- '17154'
+- '17156'
+- '17160'
+- '17168'
+- '17184'
+- '17216'
+- '17280'
+- '17408'
+- '17409'
+- '17410'
+- '17411'
+- '17412'
+- '17413'
+- '17414'
+- '17416'
+- '17417'
+- '17418'
+- '17420'
+- '17424'
+- '17425'
+- '17426'
+- '17428'
+- '17432'
+- '17440'
+- '17441'
+- '17442'
+- '17444'
+- '17448'
+- '17456'
+- '17472'
+- '17473'
+- '17474'
+- '17476'
+- '17480'
+- '17488'
+- '17504'
+- '17536'
+- '17537'
+- '17538'
+- '17540'
+- '17544'
+- '17552'
+- '17568'
+- '17600'
+- '17664'
+- '17665'
+- '17666'
+- '17668'
+- '17672'
+- '17680'
+- '17696'
+- '17728'
+- '17792'
+- '17920'
+- '17921'
+- '17922'
+- '17924'
+- '17928'
+- '17936'
+- '17952'
+- '17984'
+- '18048'
+- '18176'
+- '18432'
+- '18433'
+- '18434'
+- '18435'
+- '18436'
+- '18437'
+- '18438'
+- '18440'
+- '18441'
+- '18442'
+- '18444'
+- '18448'
+- '18449'
+- '18450'
+- '18452'
+- '18456'
+- '18464'
+- '18465'
+- '18466'
+- '18468'
+- '18472'
+- '18480'
+- '18496'
+- '18497'
+- '18498'
+- '18500'
+- '18504'
+- '18512'
+- '18528'
+- '18560'
+- '18561'
+- '18562'
+- '18564'
+- '18568'
+- '18576'
+- '18592'
+- '18624'
+- '18688'
+- '18689'
+- '18690'
+- '18692'
+- '18696'
+- '18704'
+- '18720'
+- '18752'
+- '18816'
+- '18944'
+- '18945'
+- '18946'
+- '18948'
+- '18952'
+- '18960'
+- '18976'
+- '19008'
+- '19072'
+- '19200'
+- '19456'
+- '19457'
+- '19458'
+- '19460'
+- '19464'
+- '19472'
+- '19488'
+- '19520'
+- '19584'
+- '19712'
+- '19968'
+- '20480'
+- '20481'
+- '20482'
+- '20483'
+- '20484'
+- '20485'
+- '20486'
+- '20488'
+- '20489'
+- '20490'
+- '20492'
+- '20496'
+- '20497'
+- '20498'
+- '20500'
+- '20504'
+- '20512'
+- '20513'
+- '20514'
+- '20516'
+- '20520'
+- '20528'
+- '20544'
+- '20545'
+- '20546'
+- '20548'
+- '20552'
+- '20560'
+- '20576'
+- '20608'
+- '20609'
+- '20610'
+- '20612'
+- '20616'
+- '20624'
+- '20640'
+- '20672'
+- '20736'
+- '20737'
+- '20738'
+- '20740'
+- '20744'
+- '20752'
+- '20768'
+- '20800'
+- '20864'
+- '20992'
+- '20993'
+- '20994'
+- '20996'
+- '21000'
+- '21008'
+- '21024'
+- '21056'
+- '21120'
+- '21248'
+- '21504'
+- '21505'
+- '21506'
+- '21508'
+- '21512'
+- '21520'
+- '21536'
+- '21568'
+- '21632'
+- '21760'
+- '22016'
+- '22528'
+- '22529'
+- '22530'
+- '22532'
+- '22536'
+- '22544'
+- '22560'
+- '22592'
+- '22656'
+- '22784'
+- '23040'
+- '23552'
+- '24576'
+- '24577'
+- '24578'
+- '24579'
+- '24580'
+- '24581'
+- '24582'
+- '24584'
+- '24585'
+- '24586'
+- '24588'
+- '24592'
+- '24593'
+- '24594'
+- '24596'
+- '24600'
+- '24608'
+- '24609'
+- '24610'
+- '24612'
+- '24616'
+- '24624'
+- '24640'
+- '24641'
+- '24642'
+- '24644'
+- '24648'
+- '24656'
+- '24672'
+- '24704'
+- '24705'
+- '24706'
+- '24708'
+- '24712'
+- '24720'
+- '24736'
+- '24768'
+- '24832'
+- '24833'
+- '24834'
+- '24836'
+- '24840'
+- '24848'
+- '24864'
+- '24896'
+- '24960'
+- '25088'
+- '25089'
+- '25090'
+- '25092'
+- '25096'
+- '25104'
+- '25120'
+- '25152'
+- '25216'
+- '25344'
+- '25600'
+- '25601'
+- '25602'
+- '25604'
+- '25608'
+- '25616'
+- '25632'
+- '25664'
+- '25728'
+- '25856'
+- '26112'
+- '26624'
+- '26625'
+- '26626'
+- '26628'
+- '26632'
+- '26640'
+- '26656'
+- '26688'
+- '26752'
+- '26880'
+- '27136'
+- '27648'
+- '28672'
+- '28673'
+- '28674'
+- '28676'
+- '28680'
+- '28688'
+- '28704'
+- '28736'
+- '28800'
+- '28928'
+- '29184'
+- '29696'
+- '30720'
+- '32768'
+- '32769'
+- '32770'
+- '32771'
+- '32772'
+- '32773'
+- '32774'
+- '32775'
+- '32776'
+- '32777'
+- '32778'
+- '32779'
+- '32780'
+- '32781'
+- '32782'
+- '32784'
+- '32785'
+- '32786'
+- '32787'
+- '32788'
+- '32789'
+- '32790'
+- '32792'
+- '32793'
+- '32794'
+- '32796'
+- '32800'
+- '32801'
+- '32802'
+- '32803'
+- '32804'
+- '32805'
+- '32806'
+- '32808'
+- '32809'
+- '32810'
+- '32812'
+- '32816'
+- '32817'
+- '32818'
+- '32820'
+- '32824'
+- '32832'
+- '32833'
+- '32834'
+- '32835'
+- '32836'
+- '32837'
+- '32838'
+- '32840'
+- '32841'
+- '32842'
+- '32844'
+- '32848'
+- '32849'
+- '32850'
+- '32852'
+- '32856'
+- '32864'
+- '32865'
+- '32866'
+- '32868'
+- '32872'
+- '32880'
+- '32896'
+- '32897'
+- '32898'
+- '32899'
+- '32900'
+- '32901'
+- '32902'
+- '32904'
+- '32905'
+- '32906'
+- '32908'
+- '32912'
+- '32913'
+- '32914'
+- '32916'
+- '32920'
+- '32928'
+- '32929'
+- '32930'
+- '32932'
+- '32936'
+- '32944'
+- '32960'
+- '32961'
+- '32962'
+- '32964'
+- '32968'
+- '32976'
+- '32992'
+- '33024'
+- '33025'
+- '33026'
+- '33027'
+- '33028'
+- '33029'
+- '33030'
+- '33032'
+- '33033'
+- '33034'
+- '33036'
+- '33040'
+- '33041'
+- '33042'
+- '33044'
+- '33048'
+- '33056'
+- '33057'
+- '33058'
+- '33060'
+- '33064'
+- '33072'
+- '33088'
+- '33089'
+- '33090'
+- '33092'
+- '33096'
+- '33104'
+- '33120'
+- '33152'
+- '33153'
+- '33154'
+- '33156'
+- '33160'
+- '33168'
+- '33184'
+- '33216'
+- '33280'
+- '33281'
+- '33282'
+- '33283'
+- '33284'
+- '33285'
+- '33286'
+- '33288'
+- '33289'
+- '33290'
+- '33292'
+- '33296'
+- '33297'
+- '33298'
+- '33300'
+- '33304'
+- '33312'
+- '33313'
+- '33314'
+- '33316'
+- '33320'
+- '33328'
+- '33344'
+- '33345'
+- '33346'
+- '33348'
+- '33352'
+- '33360'
+- '33376'
+- '33408'
+- '33409'
+- '33410'
+- '33412'
+- '33416'
+- '33424'
+- '33440'
+- '33472'
+- '33536'
+- '33537'
+- '33538'
+- '33540'
+- '33544'
+- '33552'
+- '33568'
+- '33600'
+- '33664'
+- '33792'
+- '33793'
+- '33794'
+- '33795'
+- '33796'
+- '33797'
+- '33798'
+- '33800'
+- '33801'
+- '33802'
+- '33804'
+- '33808'
+- '33809'
+- '33810'
+- '33812'
+- '33816'
+- '33824'
+- '33825'
+- '33826'
+- '33828'
+- '33832'
+- '33840'
+- '33856'
+- '33857'
+- '33858'
+- '33860'
+- '33864'
+- '33872'
+- '33888'
+- '33920'
+- '33921'
+- '33922'
+- '33924'
+- '33928'
+- '33936'
+- '33952'
+- '33984'
+- '34048'
+- '34049'
+- '34050'
+- '34052'
+- '34056'
+- '34064'
+- '34080'
+- '34112'
+- '34176'
+- '34304'
+- '34305'
+- '34306'
+- '34308'
+- '34312'
+- '34320'
+- '34336'
+- '34368'
+- '34432'
+- '34560'
+- '34816'
+- '34817'
+- '34818'
+- '34819'
+- '34820'
+- '34821'
+- '34822'
+- '34824'
+- '34825'
+- '34826'
+- '34828'
+- '34832'
+- '34833'
+- '34834'
+- '34836'
+- '34840'
+- '34848'
+- '34849'
+- '34850'
+- '34852'
+- '34856'
+- '34864'
+- '34880'
+- '34881'
+- '34882'
+- '34884'
+- '34888'
+- '34896'
+- '34912'
+- '34944'
+- '34945'
+- '34946'
+- '34948'
+- '34952'
+- '34960'
+- '34976'
+- '35008'
+- '35072'
+- '35073'
+- '35074'
+- '35076'
+- '35080'
+- '35088'
+- '35104'
+- '35136'
+- '35200'
+- '35328'
+- '35329'
+- '35330'
+- '35332'
+- '35336'
+- '35344'
+- '35360'
+- '35392'
+- '35456'
+- '35584'
+- '35840'
+- '35841'
+- '35842'
+- '35844'
+- '35848'
+- '35856'
+- '35872'
+- '35904'
+- '35968'
+- '36096'
+- '36352'
+- '36864'
+- '36865'
+- '36866'
+- '36867'
+- '36868'
+- '36869'
+- '36870'
+- '36872'
+- '36873'
+- '36874'
+- '36876'
+- '36880'
+- '36881'
+- '36882'
+- '36884'
+- '36888'
+- '36896'
+- '36897'
+- '36898'
+- '36900'
+- '36904'
+- '36912'
+- '36928'
+- '36929'
+- '36930'
+- '36932'
+- '36936'
+- '36944'
+- '36960'
+- '36992'
+- '36993'
+- '36994'
+- '36996'
+- '37000'
+- '37008'
+- '37024'
+- '37056'
+- '37120'
+- '37121'
+- '37122'
+- '37124'
+- '37128'
+- '37136'
+- '37152'
+- '37184'
+- '37248'
+- '37376'
+- '37377'
+- '37378'
+- '37380'
+- '37384'
+- '37392'
+- '37408'
+- '37440'
+- '37504'
+- '37632'
+- '37888'
+- '37889'
+- '37890'
+- '37892'
+- '37896'
+- '37904'
+- '37920'
+- '37952'
+- '38016'
+- '38144'
+- '38400'
+- '38912'
+- '38913'
+- '38914'
+- '38916'
+- '38920'
+- '38928'
+- '38944'
+- '38976'
+- '39040'
+- '39168'
+- '39424'
+- '39936'
+- '40960'
+- '40961'
+- '40962'
+- '40963'
+- '40964'
+- '40965'
+- '40966'
+- '40968'
+- '40969'
+- '40970'
+- '40972'
+- '40976'
+- '40977'
+- '40978'
+- '40980'
+- '40984'
+- '40992'
+- '40993'
+- '40994'
+- '40996'
+- '41000'
+- '41008'
+- '41024'
+- '41025'
+- '41026'
+- '41028'
+- '41032'
+- '41040'
+- '41056'
+- '41088'
+- '41089'
+- '41090'
+- '41092'
+- '41096'
+- '41104'
+- '41120'
+- '41152'
+- '41216'
+- '41217'
+- '41218'
+- '41220'
+- '41224'
+- '41232'
+- '41248'
+- '41280'
+- '41344'
+- '41472'
+- '41473'
+- '41474'
+- '41476'
+- '41480'
+- '41488'
+- '41504'
+- '41536'
+- '41600'
+- '41728'
+- '41984'
+- '41985'
+- '41986'
+- '41988'
+- '41992'
+- '42000'
+- '42016'
+- '42048'
+- '42112'
+- '42240'
+- '42496'
+- '43008'
+- '43009'
+- '43010'
+- '43012'
+- '43016'
+- '43024'
+- '43040'
+- '43072'
+- '43136'
+- '43264'
+- '43520'
+- '44032'
+- '45056'
+- '45057'
+- '45058'
+- '45060'
+- '45064'
+- '45072'
+- '45088'
+- '45120'
+- '45184'
+- '45312'
+- '45568'
+- '46080'
+- '47104'
+- '49152'
+- '49153'
+- '49154'
+- '49155'
+- '49156'
+- '49157'
+- '49158'
+- '49160'
+- '49161'
+- '49162'
+- '49164'
+- '49168'
+- '49169'
+- '49170'
+- '49172'
+- '49176'
+- '49184'
+- '49185'
+- '49186'
+- '49188'
+- '49192'
+- '49200'
+- '49216'
+- '49217'
+- '49218'
+- '49220'
+- '49224'
+- '49232'
+- '49248'
+- '49280'
+- '49281'
+- '49282'
+- '49284'
+- '49288'
+- '49296'
+- '49312'
+- '49344'
+- '49408'
+- '49409'
+- '49410'
+- '49412'
+- '49416'
+- '49424'
+- '49440'
+- '49472'
+- '49536'
+- '49664'
+- '49665'
+- '49666'
+- '49668'
+- '49672'
+- '49680'
+- '49696'
+- '49728'
+- '49792'
+- '49920'
+- '50176'
+- '50177'
+- '50178'
+- '50180'
+- '50184'
+- '50192'
+- '50208'
+- '50240'
+- '50304'
+- '50432'
+- '50688'
+- '51200'
+- '51201'
+- '51202'
+- '51204'
+- '51208'
+- '51216'
+- '51232'
+- '51264'
+- '51328'
+- '51456'
+- '51712'
+- '52224'
+- '53248'
+- '53249'
+- '53250'
+- '53252'
+- '53256'
+- '53264'
+- '53280'
+- '53312'
+- '53376'
+- '53504'
+- '53760'
+- '54272'
+- '55296'
+- '57344'
+- '57345'
+- '57346'
+- '57348'
+- '57352'
+- '57360'
+- '57376'
+- '57408'
+- '57472'
+- '57600'
+- '57856'
+- '58368'
+- '59392'
+- '61440'
+init: null
+input_size: null
+cmvn_file: null
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+joint_net_conf: null
+use_preprocessor: true
+token_type: char
+bpemodel: null
+non_linguistic_symbols: null
+cleaner: null
+g2p: null
+speech_volume_normalize: null
+rir_scp: null
+rir_apply_prob: 1.0
+noise_scp: null
+noise_apply_prob: 1.0
+noise_db_range: '13_15'
+specaug: null
+specaug_conf: {}
+normalize: null
+normalize_conf: {}
+label_aggregator: null
+label_aggregator_conf: {}
+model: sond
+model_conf:
+ # ctc_weight: 0.0
+ lsm_weight: 0.1
+ length_normalized_loss: true
+ max_spk_num: 16
+ # predictor_weight: 1.0
+ # predictor_bias: 1
+ # sampling_ratio: 0.75
+# speech encoder
+encoder: resnet34
+encoder_conf:
+ # pass by model, equal to feature dim
+ # input_size: 80
+ pooling_type: "window_shift"
+ pool_size: 20
+ stride: 1
+ tf2torch_tensor_name_prefix_torch: encoder
+ tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
+speaker_encoder: conv
+speaker_encoder_conf:
+ input_units: 256
+ num_layers: 3
+ num_units: 256
+ kernel_size: 1
+ dropout_rate: 0.0
+ position_encoder: null
+ out_units: 256
+ out_norm: false
+ auxiliary_states: false
+ tf2torch_tensor_name_prefix_torch: speaker_encoder
+ tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
+ci_scorer: dot
+ci_scorer_conf: {}
+cd_scorer: san
+cd_scorer_conf:
+ input_size: 512
+ output_size: 512
+ out_units: 1
+ attention_heads: 4
+ linear_units: 1024
+ num_blocks: 4
+ dropout_rate: 0.0
+ positional_dropout_rate: 0.0
+ attention_dropout_rate: 0.0
+ # use string "null" to remove input layer
+ input_layer: "null"
+ pos_enc_class: null
+ normalize_before: true
+ tf2torch_tensor_name_prefix_torch: cd_scorer
+ tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
+# post net
+decoder: fsmn
+decoder_conf:
+ in_units: 32
+ out_units: 2517
+ filter_size: 31
+ fsmn_num_layers: 6
+ dnn_num_layers: 1
+ num_memory_units: 512
+ ffn_inner_dim: 512
+ dropout_rate: 0.0
+ tf2torch_tensor_name_prefix_torch: decoder
+ tf2torch_tensor_name_prefix_tf: EAND/post_net
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: povey
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ filter_length_min: -1
+ filter_length_max: -1
+ lfr_m: 1
+ lfr_n: 1
+ dither: 0.0
+ snip_edges: false
+num_worker_count: 1
+required:
+- output_dir
+- token_list
+oss_bucket: 'null'
+version: 0.1.4
diff --git a/egs/alimeeting/diarization/sond/config_fbank.yaml b/egs/alimeeting/diarization/sond/config_fbank.yaml
new file mode 100644
index 0000000..cb4b8a9
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/config_fbank.yaml
@@ -0,0 +1,2728 @@
+config: finetune.yaml
+print_config: false
+log_level: INFO
+dry_run: false
+iterator_type: sequence
+output_dir: exp/sond
+ngpu: 1
+seed: 0
+num_workers: 16
+num_att_plot: 0
+dist_backend: nccl
+dist_init_method: env://
+dist_world_size: null
+dist_rank: null
+local_rank: 0
+dist_master_addr: null
+dist_master_port: null
+dist_launcher: null
+multiprocessing_distributed: true
+distributed: false
+unused_parameters: true
+sharded_ddp: false
+ddp_backend: pytorch_ddp
+cudnn_enabled: true
+cudnn_benchmark: false
+cudnn_deterministic: true
+collect_stats: false
+write_collected_feats: false
+max_epoch: 50
+patience: null
+val_scheduler_criterion:
+- valid
+- acc
+early_stopping_criterion:
+- valid
+- loss
+- min
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+nbest_averaging_interval: 0
+grad_clip: 5
+grad_clip_type: 2.0
+grad_noise: false
+accum_grad: 1
+no_forward_run: false
+resume: true
+train_dtype: float32
+use_amp: false
+log_interval: 50
+use_matplotlib: false
+use_tensorboard: true
+use_wandb: false
+wandb_project: null
+wandb_id: null
+wandb_entity: null
+wandb_name: null
+wandb_model_log_interval: -1
+use_pai: true
+detect_anomaly: false
+pretrain_path: null
+init_param: []
+ignore_init_mismatch: false
+freeze_param: []
+num_iters_per_epoch: null
+batch_size: 20
+valid_batch_size: null
+batch_bins: 10000
+valid_batch_bins: null
+train_shape_file:
+- /data/volume1/youyan/aishell/ark/train/speech_shape.1
+- /data/volume1/youyan/aishell/ark/train/text_shape.1
+valid_shape_file:
+- /data/volume1/youyan/aishell/ark/dev/speech_shape.1
+- /data/volume1/youyan/aishell/ark/dev/text_shape.1
+batch_type: length
+valid_batch_type: null
+fold_length:
+- 512
+- 150
+sort_in_batch: descending
+sort_batch: descending
+multiple_iterator: false
+chunk_length: 500
+chunk_shift_ratio: 0.5
+num_cache_chunks: 1024
+train_data_path_and_name_and_type:
+- - /data/volume1/youyan/aishell/ark/train/data.scp
+ - speech
+ - kaldi_ark
+- - /data/volume1/youyan/aishell/ark/train/data.text.1
+ - text
+ - text
+valid_data_path_and_name_and_type:
+- - /data/volume1/youyan/aishell/ark/dev/data.scp
+ - speech
+ - kaldi_ark
+- - /data/volume1/youyan/aishell/ark/dev/data.text.1
+ - text
+ - text
+allow_variable_data_keys: false
+max_cache_size: 0.0
+max_cache_fd: 32
+valid_max_cache_size: null
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+token_list:
+- '0'
+- '1'
+- '2'
+- '3'
+- '4'
+- '5'
+- '6'
+- '7'
+- '8'
+- '9'
+- '10'
+- '11'
+- '12'
+- '13'
+- '14'
+- '15'
+- '16'
+- '17'
+- '18'
+- '19'
+- '20'
+- '21'
+- '22'
+- '23'
+- '24'
+- '25'
+- '26'
+- '27'
+- '28'
+- '29'
+- '30'
+- '32'
+- '33'
+- '34'
+- '35'
+- '36'
+- '37'
+- '38'
+- '39'
+- '40'
+- '41'
+- '42'
+- '43'
+- '44'
+- '45'
+- '46'
+- '48'
+- '49'
+- '50'
+- '51'
+- '52'
+- '53'
+- '54'
+- '56'
+- '57'
+- '58'
+- '60'
+- '64'
+- '65'
+- '66'
+- '67'
+- '68'
+- '69'
+- '70'
+- '71'
+- '72'
+- '73'
+- '74'
+- '75'
+- '76'
+- '77'
+- '78'
+- '80'
+- '81'
+- '82'
+- '83'
+- '84'
+- '85'
+- '86'
+- '88'
+- '89'
+- '90'
+- '92'
+- '96'
+- '97'
+- '98'
+- '99'
+- '100'
+- '101'
+- '102'
+- '104'
+- '105'
+- '106'
+- '108'
+- '112'
+- '113'
+- '114'
+- '116'
+- '120'
+- '128'
+- '129'
+- '130'
+- '131'
+- '132'
+- '133'
+- '134'
+- '135'
+- '136'
+- '137'
+- '138'
+- '139'
+- '140'
+- '141'
+- '142'
+- '144'
+- '145'
+- '146'
+- '147'
+- '148'
+- '149'
+- '150'
+- '152'
+- '153'
+- '154'
+- '156'
+- '160'
+- '161'
+- '162'
+- '163'
+- '164'
+- '165'
+- '166'
+- '168'
+- '169'
+- '170'
+- '172'
+- '176'
+- '177'
+- '178'
+- '180'
+- '184'
+- '192'
+- '193'
+- '194'
+- '195'
+- '196'
+- '197'
+- '198'
+- '200'
+- '201'
+- '202'
+- '204'
+- '208'
+- '209'
+- '210'
+- '212'
+- '216'
+- '224'
+- '225'
+- '226'
+- '228'
+- '232'
+- '240'
+- '256'
+- '257'
+- '258'
+- '259'
+- '260'
+- '261'
+- '262'
+- '263'
+- '264'
+- '265'
+- '266'
+- '267'
+- '268'
+- '269'
+- '270'
+- '272'
+- '273'
+- '274'
+- '275'
+- '276'
+- '277'
+- '278'
+- '280'
+- '281'
+- '282'
+- '284'
+- '288'
+- '289'
+- '290'
+- '291'
+- '292'
+- '293'
+- '294'
+- '296'
+- '297'
+- '298'
+- '300'
+- '304'
+- '305'
+- '306'
+- '308'
+- '312'
+- '320'
+- '321'
+- '322'
+- '323'
+- '324'
+- '325'
+- '326'
+- '328'
+- '329'
+- '330'
+- '332'
+- '336'
+- '337'
+- '338'
+- '340'
+- '344'
+- '352'
+- '353'
+- '354'
+- '356'
+- '360'
+- '368'
+- '384'
+- '385'
+- '386'
+- '387'
+- '388'
+- '389'
+- '390'
+- '392'
+- '393'
+- '394'
+- '396'
+- '400'
+- '401'
+- '402'
+- '404'
+- '408'
+- '416'
+- '417'
+- '418'
+- '420'
+- '424'
+- '432'
+- '448'
+- '449'
+- '450'
+- '452'
+- '456'
+- '464'
+- '480'
+- '512'
+- '513'
+- '514'
+- '515'
+- '516'
+- '517'
+- '518'
+- '519'
+- '520'
+- '521'
+- '522'
+- '523'
+- '524'
+- '525'
+- '526'
+- '528'
+- '529'
+- '530'
+- '531'
+- '532'
+- '533'
+- '534'
+- '536'
+- '537'
+- '538'
+- '540'
+- '544'
+- '545'
+- '546'
+- '547'
+- '548'
+- '549'
+- '550'
+- '552'
+- '553'
+- '554'
+- '556'
+- '560'
+- '561'
+- '562'
+- '564'
+- '568'
+- '576'
+- '577'
+- '578'
+- '579'
+- '580'
+- '581'
+- '582'
+- '584'
+- '585'
+- '586'
+- '588'
+- '592'
+- '593'
+- '594'
+- '596'
+- '600'
+- '608'
+- '609'
+- '610'
+- '612'
+- '616'
+- '624'
+- '640'
+- '641'
+- '642'
+- '643'
+- '644'
+- '645'
+- '646'
+- '648'
+- '649'
+- '650'
+- '652'
+- '656'
+- '657'
+- '658'
+- '660'
+- '664'
+- '672'
+- '673'
+- '674'
+- '676'
+- '680'
+- '688'
+- '704'
+- '705'
+- '706'
+- '708'
+- '712'
+- '720'
+- '736'
+- '768'
+- '769'
+- '770'
+- '771'
+- '772'
+- '773'
+- '774'
+- '776'
+- '777'
+- '778'
+- '780'
+- '784'
+- '785'
+- '786'
+- '788'
+- '792'
+- '800'
+- '801'
+- '802'
+- '804'
+- '808'
+- '816'
+- '832'
+- '833'
+- '834'
+- '836'
+- '840'
+- '848'
+- '864'
+- '896'
+- '897'
+- '898'
+- '900'
+- '904'
+- '912'
+- '928'
+- '960'
+- '1024'
+- '1025'
+- '1026'
+- '1027'
+- '1028'
+- '1029'
+- '1030'
+- '1031'
+- '1032'
+- '1033'
+- '1034'
+- '1035'
+- '1036'
+- '1037'
+- '1038'
+- '1040'
+- '1041'
+- '1042'
+- '1043'
+- '1044'
+- '1045'
+- '1046'
+- '1048'
+- '1049'
+- '1050'
+- '1052'
+- '1056'
+- '1057'
+- '1058'
+- '1059'
+- '1060'
+- '1061'
+- '1062'
+- '1064'
+- '1065'
+- '1066'
+- '1068'
+- '1072'
+- '1073'
+- '1074'
+- '1076'
+- '1080'
+- '1088'
+- '1089'
+- '1090'
+- '1091'
+- '1092'
+- '1093'
+- '1094'
+- '1096'
+- '1097'
+- '1098'
+- '1100'
+- '1104'
+- '1105'
+- '1106'
+- '1108'
+- '1112'
+- '1120'
+- '1121'
+- '1122'
+- '1124'
+- '1128'
+- '1136'
+- '1152'
+- '1153'
+- '1154'
+- '1155'
+- '1156'
+- '1157'
+- '1158'
+- '1160'
+- '1161'
+- '1162'
+- '1164'
+- '1168'
+- '1169'
+- '1170'
+- '1172'
+- '1176'
+- '1184'
+- '1185'
+- '1186'
+- '1188'
+- '1192'
+- '1200'
+- '1216'
+- '1217'
+- '1218'
+- '1220'
+- '1224'
+- '1232'
+- '1248'
+- '1280'
+- '1281'
+- '1282'
+- '1283'
+- '1284'
+- '1285'
+- '1286'
+- '1288'
+- '1289'
+- '1290'
+- '1292'
+- '1296'
+- '1297'
+- '1298'
+- '1300'
+- '1304'
+- '1312'
+- '1313'
+- '1314'
+- '1316'
+- '1320'
+- '1328'
+- '1344'
+- '1345'
+- '1346'
+- '1348'
+- '1352'
+- '1360'
+- '1376'
+- '1408'
+- '1409'
+- '1410'
+- '1412'
+- '1416'
+- '1424'
+- '1440'
+- '1472'
+- '1536'
+- '1537'
+- '1538'
+- '1539'
+- '1540'
+- '1541'
+- '1542'
+- '1544'
+- '1545'
+- '1546'
+- '1548'
+- '1552'
+- '1553'
+- '1554'
+- '1556'
+- '1560'
+- '1568'
+- '1569'
+- '1570'
+- '1572'
+- '1576'
+- '1584'
+- '1600'
+- '1601'
+- '1602'
+- '1604'
+- '1608'
+- '1616'
+- '1632'
+- '1664'
+- '1665'
+- '1666'
+- '1668'
+- '1672'
+- '1680'
+- '1696'
+- '1728'
+- '1792'
+- '1793'
+- '1794'
+- '1796'
+- '1800'
+- '1808'
+- '1824'
+- '1856'
+- '1920'
+- '2048'
+- '2049'
+- '2050'
+- '2051'
+- '2052'
+- '2053'
+- '2054'
+- '2055'
+- '2056'
+- '2057'
+- '2058'
+- '2059'
+- '2060'
+- '2061'
+- '2062'
+- '2064'
+- '2065'
+- '2066'
+- '2067'
+- '2068'
+- '2069'
+- '2070'
+- '2072'
+- '2073'
+- '2074'
+- '2076'
+- '2080'
+- '2081'
+- '2082'
+- '2083'
+- '2084'
+- '2085'
+- '2086'
+- '2088'
+- '2089'
+- '2090'
+- '2092'
+- '2096'
+- '2097'
+- '2098'
+- '2100'
+- '2104'
+- '2112'
+- '2113'
+- '2114'
+- '2115'
+- '2116'
+- '2117'
+- '2118'
+- '2120'
+- '2121'
+- '2122'
+- '2124'
+- '2128'
+- '2129'
+- '2130'
+- '2132'
+- '2136'
+- '2144'
+- '2145'
+- '2146'
+- '2148'
+- '2152'
+- '2160'
+- '2176'
+- '2177'
+- '2178'
+- '2179'
+- '2180'
+- '2181'
+- '2182'
+- '2184'
+- '2185'
+- '2186'
+- '2188'
+- '2192'
+- '2193'
+- '2194'
+- '2196'
+- '2200'
+- '2208'
+- '2209'
+- '2210'
+- '2212'
+- '2216'
+- '2224'
+- '2240'
+- '2241'
+- '2242'
+- '2244'
+- '2248'
+- '2256'
+- '2272'
+- '2304'
+- '2305'
+- '2306'
+- '2307'
+- '2308'
+- '2309'
+- '2310'
+- '2312'
+- '2313'
+- '2314'
+- '2316'
+- '2320'
+- '2321'
+- '2322'
+- '2324'
+- '2328'
+- '2336'
+- '2337'
+- '2338'
+- '2340'
+- '2344'
+- '2352'
+- '2368'
+- '2369'
+- '2370'
+- '2372'
+- '2376'
+- '2384'
+- '2400'
+- '2432'
+- '2433'
+- '2434'
+- '2436'
+- '2440'
+- '2448'
+- '2464'
+- '2496'
+- '2560'
+- '2561'
+- '2562'
+- '2563'
+- '2564'
+- '2565'
+- '2566'
+- '2568'
+- '2569'
+- '2570'
+- '2572'
+- '2576'
+- '2577'
+- '2578'
+- '2580'
+- '2584'
+- '2592'
+- '2593'
+- '2594'
+- '2596'
+- '2600'
+- '2608'
+- '2624'
+- '2625'
+- '2626'
+- '2628'
+- '2632'
+- '2640'
+- '2656'
+- '2688'
+- '2689'
+- '2690'
+- '2692'
+- '2696'
+- '2704'
+- '2720'
+- '2752'
+- '2816'
+- '2817'
+- '2818'
+- '2820'
+- '2824'
+- '2832'
+- '2848'
+- '2880'
+- '2944'
+- '3072'
+- '3073'
+- '3074'
+- '3075'
+- '3076'
+- '3077'
+- '3078'
+- '3080'
+- '3081'
+- '3082'
+- '3084'
+- '3088'
+- '3089'
+- '3090'
+- '3092'
+- '3096'
+- '3104'
+- '3105'
+- '3106'
+- '3108'
+- '3112'
+- '3120'
+- '3136'
+- '3137'
+- '3138'
+- '3140'
+- '3144'
+- '3152'
+- '3168'
+- '3200'
+- '3201'
+- '3202'
+- '3204'
+- '3208'
+- '3216'
+- '3232'
+- '3264'
+- '3328'
+- '3329'
+- '3330'
+- '3332'
+- '3336'
+- '3344'
+- '3360'
+- '3392'
+- '3456'
+- '3584'
+- '3585'
+- '3586'
+- '3588'
+- '3592'
+- '3600'
+- '3616'
+- '3648'
+- '3712'
+- '3840'
+- '4096'
+- '4097'
+- '4098'
+- '4099'
+- '4100'
+- '4101'
+- '4102'
+- '4103'
+- '4104'
+- '4105'
+- '4106'
+- '4107'
+- '4108'
+- '4109'
+- '4110'
+- '4112'
+- '4113'
+- '4114'
+- '4115'
+- '4116'
+- '4117'
+- '4118'
+- '4120'
+- '4121'
+- '4122'
+- '4124'
+- '4128'
+- '4129'
+- '4130'
+- '4131'
+- '4132'
+- '4133'
+- '4134'
+- '4136'
+- '4137'
+- '4138'
+- '4140'
+- '4144'
+- '4145'
+- '4146'
+- '4148'
+- '4152'
+- '4160'
+- '4161'
+- '4162'
+- '4163'
+- '4164'
+- '4165'
+- '4166'
+- '4168'
+- '4169'
+- '4170'
+- '4172'
+- '4176'
+- '4177'
+- '4178'
+- '4180'
+- '4184'
+- '4192'
+- '4193'
+- '4194'
+- '4196'
+- '4200'
+- '4208'
+- '4224'
+- '4225'
+- '4226'
+- '4227'
+- '4228'
+- '4229'
+- '4230'
+- '4232'
+- '4233'
+- '4234'
+- '4236'
+- '4240'
+- '4241'
+- '4242'
+- '4244'
+- '4248'
+- '4256'
+- '4257'
+- '4258'
+- '4260'
+- '4264'
+- '4272'
+- '4288'
+- '4289'
+- '4290'
+- '4292'
+- '4296'
+- '4304'
+- '4320'
+- '4352'
+- '4353'
+- '4354'
+- '4355'
+- '4356'
+- '4357'
+- '4358'
+- '4360'
+- '4361'
+- '4362'
+- '4364'
+- '4368'
+- '4369'
+- '4370'
+- '4372'
+- '4376'
+- '4384'
+- '4385'
+- '4386'
+- '4388'
+- '4392'
+- '4400'
+- '4416'
+- '4417'
+- '4418'
+- '4420'
+- '4424'
+- '4432'
+- '4448'
+- '4480'
+- '4481'
+- '4482'
+- '4484'
+- '4488'
+- '4496'
+- '4512'
+- '4544'
+- '4608'
+- '4609'
+- '4610'
+- '4611'
+- '4612'
+- '4613'
+- '4614'
+- '4616'
+- '4617'
+- '4618'
+- '4620'
+- '4624'
+- '4625'
+- '4626'
+- '4628'
+- '4632'
+- '4640'
+- '4641'
+- '4642'
+- '4644'
+- '4648'
+- '4656'
+- '4672'
+- '4673'
+- '4674'
+- '4676'
+- '4680'
+- '4688'
+- '4704'
+- '4736'
+- '4737'
+- '4738'
+- '4740'
+- '4744'
+- '4752'
+- '4768'
+- '4800'
+- '4864'
+- '4865'
+- '4866'
+- '4868'
+- '4872'
+- '4880'
+- '4896'
+- '4928'
+- '4992'
+- '5120'
+- '5121'
+- '5122'
+- '5123'
+- '5124'
+- '5125'
+- '5126'
+- '5128'
+- '5129'
+- '5130'
+- '5132'
+- '5136'
+- '5137'
+- '5138'
+- '5140'
+- '5144'
+- '5152'
+- '5153'
+- '5154'
+- '5156'
+- '5160'
+- '5168'
+- '5184'
+- '5185'
+- '5186'
+- '5188'
+- '5192'
+- '5200'
+- '5216'
+- '5248'
+- '5249'
+- '5250'
+- '5252'
+- '5256'
+- '5264'
+- '5280'
+- '5312'
+- '5376'
+- '5377'
+- '5378'
+- '5380'
+- '5384'
+- '5392'
+- '5408'
+- '5440'
+- '5504'
+- '5632'
+- '5633'
+- '5634'
+- '5636'
+- '5640'
+- '5648'
+- '5664'
+- '5696'
+- '5760'
+- '5888'
+- '6144'
+- '6145'
+- '6146'
+- '6147'
+- '6148'
+- '6149'
+- '6150'
+- '6152'
+- '6153'
+- '6154'
+- '6156'
+- '6160'
+- '6161'
+- '6162'
+- '6164'
+- '6168'
+- '6176'
+- '6177'
+- '6178'
+- '6180'
+- '6184'
+- '6192'
+- '6208'
+- '6209'
+- '6210'
+- '6212'
+- '6216'
+- '6224'
+- '6240'
+- '6272'
+- '6273'
+- '6274'
+- '6276'
+- '6280'
+- '6288'
+- '6304'
+- '6336'
+- '6400'
+- '6401'
+- '6402'
+- '6404'
+- '6408'
+- '6416'
+- '6432'
+- '6464'
+- '6528'
+- '6656'
+- '6657'
+- '6658'
+- '6660'
+- '6664'
+- '6672'
+- '6688'
+- '6720'
+- '6784'
+- '6912'
+- '7168'
+- '7169'
+- '7170'
+- '7172'
+- '7176'
+- '7184'
+- '7200'
+- '7232'
+- '7296'
+- '7424'
+- '7680'
+- '8192'
+- '8193'
+- '8194'
+- '8195'
+- '8196'
+- '8197'
+- '8198'
+- '8199'
+- '8200'
+- '8201'
+- '8202'
+- '8203'
+- '8204'
+- '8205'
+- '8206'
+- '8208'
+- '8209'
+- '8210'
+- '8211'
+- '8212'
+- '8213'
+- '8214'
+- '8216'
+- '8217'
+- '8218'
+- '8220'
+- '8224'
+- '8225'
+- '8226'
+- '8227'
+- '8228'
+- '8229'
+- '8230'
+- '8232'
+- '8233'
+- '8234'
+- '8236'
+- '8240'
+- '8241'
+- '8242'
+- '8244'
+- '8248'
+- '8256'
+- '8257'
+- '8258'
+- '8259'
+- '8260'
+- '8261'
+- '8262'
+- '8264'
+- '8265'
+- '8266'
+- '8268'
+- '8272'
+- '8273'
+- '8274'
+- '8276'
+- '8280'
+- '8288'
+- '8289'
+- '8290'
+- '8292'
+- '8296'
+- '8304'
+- '8320'
+- '8321'
+- '8322'
+- '8323'
+- '8324'
+- '8325'
+- '8326'
+- '8328'
+- '8329'
+- '8330'
+- '8332'
+- '8336'
+- '8337'
+- '8338'
+- '8340'
+- '8344'
+- '8352'
+- '8353'
+- '8354'
+- '8356'
+- '8360'
+- '8368'
+- '8384'
+- '8385'
+- '8386'
+- '8388'
+- '8392'
+- '8400'
+- '8416'
+- '8448'
+- '8449'
+- '8450'
+- '8451'
+- '8452'
+- '8453'
+- '8454'
+- '8456'
+- '8457'
+- '8458'
+- '8460'
+- '8464'
+- '8465'
+- '8466'
+- '8468'
+- '8472'
+- '8480'
+- '8481'
+- '8482'
+- '8484'
+- '8488'
+- '8496'
+- '8512'
+- '8513'
+- '8514'
+- '8516'
+- '8520'
+- '8528'
+- '8544'
+- '8576'
+- '8577'
+- '8578'
+- '8580'
+- '8584'
+- '8592'
+- '8608'
+- '8640'
+- '8704'
+- '8705'
+- '8706'
+- '8707'
+- '8708'
+- '8709'
+- '8710'
+- '8712'
+- '8713'
+- '8714'
+- '8716'
+- '8720'
+- '8721'
+- '8722'
+- '8724'
+- '8728'
+- '8736'
+- '8737'
+- '8738'
+- '8740'
+- '8744'
+- '8752'
+- '8768'
+- '8769'
+- '8770'
+- '8772'
+- '8776'
+- '8784'
+- '8800'
+- '8832'
+- '8833'
+- '8834'
+- '8836'
+- '8840'
+- '8848'
+- '8864'
+- '8896'
+- '8960'
+- '8961'
+- '8962'
+- '8964'
+- '8968'
+- '8976'
+- '8992'
+- '9024'
+- '9088'
+- '9216'
+- '9217'
+- '9218'
+- '9219'
+- '9220'
+- '9221'
+- '9222'
+- '9224'
+- '9225'
+- '9226'
+- '9228'
+- '9232'
+- '9233'
+- '9234'
+- '9236'
+- '9240'
+- '9248'
+- '9249'
+- '9250'
+- '9252'
+- '9256'
+- '9264'
+- '9280'
+- '9281'
+- '9282'
+- '9284'
+- '9288'
+- '9296'
+- '9312'
+- '9344'
+- '9345'
+- '9346'
+- '9348'
+- '9352'
+- '9360'
+- '9376'
+- '9408'
+- '9472'
+- '9473'
+- '9474'
+- '9476'
+- '9480'
+- '9488'
+- '9504'
+- '9536'
+- '9600'
+- '9728'
+- '9729'
+- '9730'
+- '9732'
+- '9736'
+- '9744'
+- '9760'
+- '9792'
+- '9856'
+- '9984'
+- '10240'
+- '10241'
+- '10242'
+- '10243'
+- '10244'
+- '10245'
+- '10246'
+- '10248'
+- '10249'
+- '10250'
+- '10252'
+- '10256'
+- '10257'
+- '10258'
+- '10260'
+- '10264'
+- '10272'
+- '10273'
+- '10274'
+- '10276'
+- '10280'
+- '10288'
+- '10304'
+- '10305'
+- '10306'
+- '10308'
+- '10312'
+- '10320'
+- '10336'
+- '10368'
+- '10369'
+- '10370'
+- '10372'
+- '10376'
+- '10384'
+- '10400'
+- '10432'
+- '10496'
+- '10497'
+- '10498'
+- '10500'
+- '10504'
+- '10512'
+- '10528'
+- '10560'
+- '10624'
+- '10752'
+- '10753'
+- '10754'
+- '10756'
+- '10760'
+- '10768'
+- '10784'
+- '10816'
+- '10880'
+- '11008'
+- '11264'
+- '11265'
+- '11266'
+- '11268'
+- '11272'
+- '11280'
+- '11296'
+- '11328'
+- '11392'
+- '11520'
+- '11776'
+- '12288'
+- '12289'
+- '12290'
+- '12291'
+- '12292'
+- '12293'
+- '12294'
+- '12296'
+- '12297'
+- '12298'
+- '12300'
+- '12304'
+- '12305'
+- '12306'
+- '12308'
+- '12312'
+- '12320'
+- '12321'
+- '12322'
+- '12324'
+- '12328'
+- '12336'
+- '12352'
+- '12353'
+- '12354'
+- '12356'
+- '12360'
+- '12368'
+- '12384'
+- '12416'
+- '12417'
+- '12418'
+- '12420'
+- '12424'
+- '12432'
+- '12448'
+- '12480'
+- '12544'
+- '12545'
+- '12546'
+- '12548'
+- '12552'
+- '12560'
+- '12576'
+- '12608'
+- '12672'
+- '12800'
+- '12801'
+- '12802'
+- '12804'
+- '12808'
+- '12816'
+- '12832'
+- '12864'
+- '12928'
+- '13056'
+- '13312'
+- '13313'
+- '13314'
+- '13316'
+- '13320'
+- '13328'
+- '13344'
+- '13376'
+- '13440'
+- '13568'
+- '13824'
+- '14336'
+- '14337'
+- '14338'
+- '14340'
+- '14344'
+- '14352'
+- '14368'
+- '14400'
+- '14464'
+- '14592'
+- '14848'
+- '15360'
+- '16384'
+- '16385'
+- '16386'
+- '16387'
+- '16388'
+- '16389'
+- '16390'
+- '16391'
+- '16392'
+- '16393'
+- '16394'
+- '16395'
+- '16396'
+- '16397'
+- '16398'
+- '16400'
+- '16401'
+- '16402'
+- '16403'
+- '16404'
+- '16405'
+- '16406'
+- '16408'
+- '16409'
+- '16410'
+- '16412'
+- '16416'
+- '16417'
+- '16418'
+- '16419'
+- '16420'
+- '16421'
+- '16422'
+- '16424'
+- '16425'
+- '16426'
+- '16428'
+- '16432'
+- '16433'
+- '16434'
+- '16436'
+- '16440'
+- '16448'
+- '16449'
+- '16450'
+- '16451'
+- '16452'
+- '16453'
+- '16454'
+- '16456'
+- '16457'
+- '16458'
+- '16460'
+- '16464'
+- '16465'
+- '16466'
+- '16468'
+- '16472'
+- '16480'
+- '16481'
+- '16482'
+- '16484'
+- '16488'
+- '16496'
+- '16512'
+- '16513'
+- '16514'
+- '16515'
+- '16516'
+- '16517'
+- '16518'
+- '16520'
+- '16521'
+- '16522'
+- '16524'
+- '16528'
+- '16529'
+- '16530'
+- '16532'
+- '16536'
+- '16544'
+- '16545'
+- '16546'
+- '16548'
+- '16552'
+- '16560'
+- '16576'
+- '16577'
+- '16578'
+- '16580'
+- '16584'
+- '16592'
+- '16608'
+- '16640'
+- '16641'
+- '16642'
+- '16643'
+- '16644'
+- '16645'
+- '16646'
+- '16648'
+- '16649'
+- '16650'
+- '16652'
+- '16656'
+- '16657'
+- '16658'
+- '16660'
+- '16664'
+- '16672'
+- '16673'
+- '16674'
+- '16676'
+- '16680'
+- '16688'
+- '16704'
+- '16705'
+- '16706'
+- '16708'
+- '16712'
+- '16720'
+- '16736'
+- '16768'
+- '16769'
+- '16770'
+- '16772'
+- '16776'
+- '16784'
+- '16800'
+- '16832'
+- '16896'
+- '16897'
+- '16898'
+- '16899'
+- '16900'
+- '16901'
+- '16902'
+- '16904'
+- '16905'
+- '16906'
+- '16908'
+- '16912'
+- '16913'
+- '16914'
+- '16916'
+- '16920'
+- '16928'
+- '16929'
+- '16930'
+- '16932'
+- '16936'
+- '16944'
+- '16960'
+- '16961'
+- '16962'
+- '16964'
+- '16968'
+- '16976'
+- '16992'
+- '17024'
+- '17025'
+- '17026'
+- '17028'
+- '17032'
+- '17040'
+- '17056'
+- '17088'
+- '17152'
+- '17153'
+- '17154'
+- '17156'
+- '17160'
+- '17168'
+- '17184'
+- '17216'
+- '17280'
+- '17408'
+- '17409'
+- '17410'
+- '17411'
+- '17412'
+- '17413'
+- '17414'
+- '17416'
+- '17417'
+- '17418'
+- '17420'
+- '17424'
+- '17425'
+- '17426'
+- '17428'
+- '17432'
+- '17440'
+- '17441'
+- '17442'
+- '17444'
+- '17448'
+- '17456'
+- '17472'
+- '17473'
+- '17474'
+- '17476'
+- '17480'
+- '17488'
+- '17504'
+- '17536'
+- '17537'
+- '17538'
+- '17540'
+- '17544'
+- '17552'
+- '17568'
+- '17600'
+- '17664'
+- '17665'
+- '17666'
+- '17668'
+- '17672'
+- '17680'
+- '17696'
+- '17728'
+- '17792'
+- '17920'
+- '17921'
+- '17922'
+- '17924'
+- '17928'
+- '17936'
+- '17952'
+- '17984'
+- '18048'
+- '18176'
+- '18432'
+- '18433'
+- '18434'
+- '18435'
+- '18436'
+- '18437'
+- '18438'
+- '18440'
+- '18441'
+- '18442'
+- '18444'
+- '18448'
+- '18449'
+- '18450'
+- '18452'
+- '18456'
+- '18464'
+- '18465'
+- '18466'
+- '18468'
+- '18472'
+- '18480'
+- '18496'
+- '18497'
+- '18498'
+- '18500'
+- '18504'
+- '18512'
+- '18528'
+- '18560'
+- '18561'
+- '18562'
+- '18564'
+- '18568'
+- '18576'
+- '18592'
+- '18624'
+- '18688'
+- '18689'
+- '18690'
+- '18692'
+- '18696'
+- '18704'
+- '18720'
+- '18752'
+- '18816'
+- '18944'
+- '18945'
+- '18946'
+- '18948'
+- '18952'
+- '18960'
+- '18976'
+- '19008'
+- '19072'
+- '19200'
+- '19456'
+- '19457'
+- '19458'
+- '19460'
+- '19464'
+- '19472'
+- '19488'
+- '19520'
+- '19584'
+- '19712'
+- '19968'
+- '20480'
+- '20481'
+- '20482'
+- '20483'
+- '20484'
+- '20485'
+- '20486'
+- '20488'
+- '20489'
+- '20490'
+- '20492'
+- '20496'
+- '20497'
+- '20498'
+- '20500'
+- '20504'
+- '20512'
+- '20513'
+- '20514'
+- '20516'
+- '20520'
+- '20528'
+- '20544'
+- '20545'
+- '20546'
+- '20548'
+- '20552'
+- '20560'
+- '20576'
+- '20608'
+- '20609'
+- '20610'
+- '20612'
+- '20616'
+- '20624'
+- '20640'
+- '20672'
+- '20736'
+- '20737'
+- '20738'
+- '20740'
+- '20744'
+- '20752'
+- '20768'
+- '20800'
+- '20864'
+- '20992'
+- '20993'
+- '20994'
+- '20996'
+- '21000'
+- '21008'
+- '21024'
+- '21056'
+- '21120'
+- '21248'
+- '21504'
+- '21505'
+- '21506'
+- '21508'
+- '21512'
+- '21520'
+- '21536'
+- '21568'
+- '21632'
+- '21760'
+- '22016'
+- '22528'
+- '22529'
+- '22530'
+- '22532'
+- '22536'
+- '22544'
+- '22560'
+- '22592'
+- '22656'
+- '22784'
+- '23040'
+- '23552'
+- '24576'
+- '24577'
+- '24578'
+- '24579'
+- '24580'
+- '24581'
+- '24582'
+- '24584'
+- '24585'
+- '24586'
+- '24588'
+- '24592'
+- '24593'
+- '24594'
+- '24596'
+- '24600'
+- '24608'
+- '24609'
+- '24610'
+- '24612'
+- '24616'
+- '24624'
+- '24640'
+- '24641'
+- '24642'
+- '24644'
+- '24648'
+- '24656'
+- '24672'
+- '24704'
+- '24705'
+- '24706'
+- '24708'
+- '24712'
+- '24720'
+- '24736'
+- '24768'
+- '24832'
+- '24833'
+- '24834'
+- '24836'
+- '24840'
+- '24848'
+- '24864'
+- '24896'
+- '24960'
+- '25088'
+- '25089'
+- '25090'
+- '25092'
+- '25096'
+- '25104'
+- '25120'
+- '25152'
+- '25216'
+- '25344'
+- '25600'
+- '25601'
+- '25602'
+- '25604'
+- '25608'
+- '25616'
+- '25632'
+- '25664'
+- '25728'
+- '25856'
+- '26112'
+- '26624'
+- '26625'
+- '26626'
+- '26628'
+- '26632'
+- '26640'
+- '26656'
+- '26688'
+- '26752'
+- '26880'
+- '27136'
+- '27648'
+- '28672'
+- '28673'
+- '28674'
+- '28676'
+- '28680'
+- '28688'
+- '28704'
+- '28736'
+- '28800'
+- '28928'
+- '29184'
+- '29696'
+- '30720'
+- '32768'
+- '32769'
+- '32770'
+- '32771'
+- '32772'
+- '32773'
+- '32774'
+- '32775'
+- '32776'
+- '32777'
+- '32778'
+- '32779'
+- '32780'
+- '32781'
+- '32782'
+- '32784'
+- '32785'
+- '32786'
+- '32787'
+- '32788'
+- '32789'
+- '32790'
+- '32792'
+- '32793'
+- '32794'
+- '32796'
+- '32800'
+- '32801'
+- '32802'
+- '32803'
+- '32804'
+- '32805'
+- '32806'
+- '32808'
+- '32809'
+- '32810'
+- '32812'
+- '32816'
+- '32817'
+- '32818'
+- '32820'
+- '32824'
+- '32832'
+- '32833'
+- '32834'
+- '32835'
+- '32836'
+- '32837'
+- '32838'
+- '32840'
+- '32841'
+- '32842'
+- '32844'
+- '32848'
+- '32849'
+- '32850'
+- '32852'
+- '32856'
+- '32864'
+- '32865'
+- '32866'
+- '32868'
+- '32872'
+- '32880'
+- '32896'
+- '32897'
+- '32898'
+- '32899'
+- '32900'
+- '32901'
+- '32902'
+- '32904'
+- '32905'
+- '32906'
+- '32908'
+- '32912'
+- '32913'
+- '32914'
+- '32916'
+- '32920'
+- '32928'
+- '32929'
+- '32930'
+- '32932'
+- '32936'
+- '32944'
+- '32960'
+- '32961'
+- '32962'
+- '32964'
+- '32968'
+- '32976'
+- '32992'
+- '33024'
+- '33025'
+- '33026'
+- '33027'
+- '33028'
+- '33029'
+- '33030'
+- '33032'
+- '33033'
+- '33034'
+- '33036'
+- '33040'
+- '33041'
+- '33042'
+- '33044'
+- '33048'
+- '33056'
+- '33057'
+- '33058'
+- '33060'
+- '33064'
+- '33072'
+- '33088'
+- '33089'
+- '33090'
+- '33092'
+- '33096'
+- '33104'
+- '33120'
+- '33152'
+- '33153'
+- '33154'
+- '33156'
+- '33160'
+- '33168'
+- '33184'
+- '33216'
+- '33280'
+- '33281'
+- '33282'
+- '33283'
+- '33284'
+- '33285'
+- '33286'
+- '33288'
+- '33289'
+- '33290'
+- '33292'
+- '33296'
+- '33297'
+- '33298'
+- '33300'
+- '33304'
+- '33312'
+- '33313'
+- '33314'
+- '33316'
+- '33320'
+- '33328'
+- '33344'
+- '33345'
+- '33346'
+- '33348'
+- '33352'
+- '33360'
+- '33376'
+- '33408'
+- '33409'
+- '33410'
+- '33412'
+- '33416'
+- '33424'
+- '33440'
+- '33472'
+- '33536'
+- '33537'
+- '33538'
+- '33540'
+- '33544'
+- '33552'
+- '33568'
+- '33600'
+- '33664'
+- '33792'
+- '33793'
+- '33794'
+- '33795'
+- '33796'
+- '33797'
+- '33798'
+- '33800'
+- '33801'
+- '33802'
+- '33804'
+- '33808'
+- '33809'
+- '33810'
+- '33812'
+- '33816'
+- '33824'
+- '33825'
+- '33826'
+- '33828'
+- '33832'
+- '33840'
+- '33856'
+- '33857'
+- '33858'
+- '33860'
+- '33864'
+- '33872'
+- '33888'
+- '33920'
+- '33921'
+- '33922'
+- '33924'
+- '33928'
+- '33936'
+- '33952'
+- '33984'
+- '34048'
+- '34049'
+- '34050'
+- '34052'
+- '34056'
+- '34064'
+- '34080'
+- '34112'
+- '34176'
+- '34304'
+- '34305'
+- '34306'
+- '34308'
+- '34312'
+- '34320'
+- '34336'
+- '34368'
+- '34432'
+- '34560'
+- '34816'
+- '34817'
+- '34818'
+- '34819'
+- '34820'
+- '34821'
+- '34822'
+- '34824'
+- '34825'
+- '34826'
+- '34828'
+- '34832'
+- '34833'
+- '34834'
+- '34836'
+- '34840'
+- '34848'
+- '34849'
+- '34850'
+- '34852'
+- '34856'
+- '34864'
+- '34880'
+- '34881'
+- '34882'
+- '34884'
+- '34888'
+- '34896'
+- '34912'
+- '34944'
+- '34945'
+- '34946'
+- '34948'
+- '34952'
+- '34960'
+- '34976'
+- '35008'
+- '35072'
+- '35073'
+- '35074'
+- '35076'
+- '35080'
+- '35088'
+- '35104'
+- '35136'
+- '35200'
+- '35328'
+- '35329'
+- '35330'
+- '35332'
+- '35336'
+- '35344'
+- '35360'
+- '35392'
+- '35456'
+- '35584'
+- '35840'
+- '35841'
+- '35842'
+- '35844'
+- '35848'
+- '35856'
+- '35872'
+- '35904'
+- '35968'
+- '36096'
+- '36352'
+- '36864'
+- '36865'
+- '36866'
+- '36867'
+- '36868'
+- '36869'
+- '36870'
+- '36872'
+- '36873'
+- '36874'
+- '36876'
+- '36880'
+- '36881'
+- '36882'
+- '36884'
+- '36888'
+- '36896'
+- '36897'
+- '36898'
+- '36900'
+- '36904'
+- '36912'
+- '36928'
+- '36929'
+- '36930'
+- '36932'
+- '36936'
+- '36944'
+- '36960'
+- '36992'
+- '36993'
+- '36994'
+- '36996'
+- '37000'
+- '37008'
+- '37024'
+- '37056'
+- '37120'
+- '37121'
+- '37122'
+- '37124'
+- '37128'
+- '37136'
+- '37152'
+- '37184'
+- '37248'
+- '37376'
+- '37377'
+- '37378'
+- '37380'
+- '37384'
+- '37392'
+- '37408'
+- '37440'
+- '37504'
+- '37632'
+- '37888'
+- '37889'
+- '37890'
+- '37892'
+- '37896'
+- '37904'
+- '37920'
+- '37952'
+- '38016'
+- '38144'
+- '38400'
+- '38912'
+- '38913'
+- '38914'
+- '38916'
+- '38920'
+- '38928'
+- '38944'
+- '38976'
+- '39040'
+- '39168'
+- '39424'
+- '39936'
+- '40960'
+- '40961'
+- '40962'
+- '40963'
+- '40964'
+- '40965'
+- '40966'
+- '40968'
+- '40969'
+- '40970'
+- '40972'
+- '40976'
+- '40977'
+- '40978'
+- '40980'
+- '40984'
+- '40992'
+- '40993'
+- '40994'
+- '40996'
+- '41000'
+- '41008'
+- '41024'
+- '41025'
+- '41026'
+- '41028'
+- '41032'
+- '41040'
+- '41056'
+- '41088'
+- '41089'
+- '41090'
+- '41092'
+- '41096'
+- '41104'
+- '41120'
+- '41152'
+- '41216'
+- '41217'
+- '41218'
+- '41220'
+- '41224'
+- '41232'
+- '41248'
+- '41280'
+- '41344'
+- '41472'
+- '41473'
+- '41474'
+- '41476'
+- '41480'
+- '41488'
+- '41504'
+- '41536'
+- '41600'
+- '41728'
+- '41984'
+- '41985'
+- '41986'
+- '41988'
+- '41992'
+- '42000'
+- '42016'
+- '42048'
+- '42112'
+- '42240'
+- '42496'
+- '43008'
+- '43009'
+- '43010'
+- '43012'
+- '43016'
+- '43024'
+- '43040'
+- '43072'
+- '43136'
+- '43264'
+- '43520'
+- '44032'
+- '45056'
+- '45057'
+- '45058'
+- '45060'
+- '45064'
+- '45072'
+- '45088'
+- '45120'
+- '45184'
+- '45312'
+- '45568'
+- '46080'
+- '47104'
+- '49152'
+- '49153'
+- '49154'
+- '49155'
+- '49156'
+- '49157'
+- '49158'
+- '49160'
+- '49161'
+- '49162'
+- '49164'
+- '49168'
+- '49169'
+- '49170'
+- '49172'
+- '49176'
+- '49184'
+- '49185'
+- '49186'
+- '49188'
+- '49192'
+- '49200'
+- '49216'
+- '49217'
+- '49218'
+- '49220'
+- '49224'
+- '49232'
+- '49248'
+- '49280'
+- '49281'
+- '49282'
+- '49284'
+- '49288'
+- '49296'
+- '49312'
+- '49344'
+- '49408'
+- '49409'
+- '49410'
+- '49412'
+- '49416'
+- '49424'
+- '49440'
+- '49472'
+- '49536'
+- '49664'
+- '49665'
+- '49666'
+- '49668'
+- '49672'
+- '49680'
+- '49696'
+- '49728'
+- '49792'
+- '49920'
+- '50176'
+- '50177'
+- '50178'
+- '50180'
+- '50184'
+- '50192'
+- '50208'
+- '50240'
+- '50304'
+- '50432'
+- '50688'
+- '51200'
+- '51201'
+- '51202'
+- '51204'
+- '51208'
+- '51216'
+- '51232'
+- '51264'
+- '51328'
+- '51456'
+- '51712'
+- '52224'
+- '53248'
+- '53249'
+- '53250'
+- '53252'
+- '53256'
+- '53264'
+- '53280'
+- '53312'
+- '53376'
+- '53504'
+- '53760'
+- '54272'
+- '55296'
+- '57344'
+- '57345'
+- '57346'
+- '57348'
+- '57352'
+- '57360'
+- '57376'
+- '57408'
+- '57472'
+- '57600'
+- '57856'
+- '58368'
+- '59392'
+- '61440'
+init: null
+input_size: 80
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+joint_net_conf: null
+use_preprocessor: true
+token_type: char
+bpemodel: null
+non_linguistic_symbols: null
+cleaner: null
+g2p: null
+speech_volume_normalize: null
+rir_scp: null
+rir_apply_prob: 1.0
+noise_scp: null
+noise_apply_prob: 1.0
+noise_db_range: '13_15'
+frontend: null
+frontend_conf: {}
+specaug: null
+specaug_conf: {}
+normalize: null
+normalize_conf: {}
+label_aggregator: null
+label_aggregator_conf: {}
+model: sond
+model_conf:
+ # ctc_weight: 0.0
+ lsm_weight: 0.1
+ length_normalized_loss: true
+ max_spk_num: 16
+ # predictor_weight: 1.0
+ # predictor_bias: 1
+ # sampling_ratio: 0.75
+# speech encoder
+encoder: resnet34
+encoder_conf:
+ # pass by model, equal to feature dim
+ # input_size: 80
+ pooling_type: "window_shift"
+ pool_size: 20
+ stride: 1
+ tf2torch_tensor_name_prefix_torch: encoder
+ tf2torch_tensor_name_prefix_tf: EAND/speech_encoder
+speaker_encoder: conv
+speaker_encoder_conf:
+ input_units: 256
+ num_layers: 3
+ num_units: 256
+ kernel_size: 1
+ dropout_rate: 0.0
+ position_encoder: null
+ out_units: 256
+ out_norm: false
+ auxiliary_states: false
+ tf2torch_tensor_name_prefix_torch: speaker_encoder
+ tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
+ci_scorer: dot
+ci_scorer_conf: {}
+cd_scorer: san
+cd_scorer_conf:
+ input_size: 512
+ output_size: 512
+ out_units: 1
+ attention_heads: 4
+ linear_units: 1024
+ num_blocks: 4
+ dropout_rate: 0.0
+ positional_dropout_rate: 0.0
+ attention_dropout_rate: 0.0
+ # use string "null" to remove input layer
+ input_layer: "null"
+ pos_enc_class: null
+ normalize_before: true
+ tf2torch_tensor_name_prefix_torch: cd_scorer
+ tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
+# post net
+decoder: fsmn
+decoder_conf:
+ in_units: 32
+ out_units: 2517
+ filter_size: 31
+ fsmn_num_layers: 6
+ dnn_num_layers: 1
+ num_memory_units: 512
+ ffn_inner_dim: 512
+ dropout_rate: 0.0
+ tf2torch_tensor_name_prefix_torch: decoder
+ tf2torch_tensor_name_prefix_tf: EAND/post_net
+num_worker_count: 1
+required:
+- output_dir
+- token_list
+oss_bucket: 'null'
+version: 0.1.6
diff --git a/egs/alimeeting/diarization/sond/infer_alimeeting_test.py b/egs/alimeeting/diarization/sond/infer_alimeeting_test.py
new file mode 100644
index 0000000..0988f5d
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/infer_alimeeting_test.py
@@ -0,0 +1,24 @@
+from funasr.bin.diar_inference_launch import inference_launch
+import sys
+
+
def main():
    """Run SOND diarization inference on the AliMeeting test set.

    Optional positional argv: [config_yaml] [model_pth] [output_dir];
    each falls back to a default next to this script.
    """
    argv = sys.argv
    config_file = argv[1] if len(argv) > 1 else "sond_fbank.yaml"
    model_file = argv[2] if len(argv) > 2 else "sond.pth"
    out_dir = argv[3] if len(argv) > 3 else "./outputs"
    # (path, input name, loader type) triples consumed by the pipeline.
    inputs = [
        ("data/test_rmsil/feats.scp", "speech", "kaldi_ark"),
        ("data/test_rmsil/test_rmsil_tdnn6_xvec.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config=config_file,
        diar_model_file=model_file,
        output_dir=out_dir,
        num_workers=1,
    )
    pipeline(inputs)
+
+
if __name__ == '__main__':
    # Command-line entry point.
    main()
diff --git a/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py b/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py
new file mode 100644
index 0000000..880f60f
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/local/convert_label_to_rttm.py
@@ -0,0 +1,132 @@
+import os
+from funasr.utils.job_runner import MultiProcessRunnerV3
+import numpy as np
+from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
+from collections import OrderedDict
+from tqdm import tqdm
+from scipy.ndimage import median_filter
+
+
class MyRunner(MultiProcessRunnerV3):
    """Parallel runner that converts chunk-level diarization labels into a
    single RTTM file, one task per meeting."""

    def prepare(self, parser):
        """Parse CLI arguments and build the per-meeting task list.

        Returns:
            (task_list, None, args) where each task is
            (meeting_id, [chunk_label, ...], segment_map_file_path).
        """
        parser.add_argument("label_txt", type=str)
        parser.add_argument("map_scp", type=str)
        parser.add_argument("out_rttm", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--chunk_len", type=int, default=1600)
        parser.add_argument("--shift_len", type=int, default=400)
        parser.add_argument("--ignore_len", type=int, default=5)
        parser.add_argument("--smooth_size", type=int, default=7)
        parser.add_argument("--vote_prob", type=float, default=0.5)
        args = parser.parse_args()
        # NOTE(review): process() also reads args.sr and args.no_pbar, which
        # are not declared here -- presumably added by the MultiProcessRunnerV3
        # base parser; confirm before renaming or refactoring.

        if not os.path.exists(os.path.dirname(args.out_rttm)):
            os.makedirs(os.path.dirname(args.out_rttm))

        # Group chunk labels by meeting id (utterance-id prefix before '-').
        utt2labels = load_scp_as_list(args.label_txt, 'list')
        utt2labels = sorted(utt2labels, key=lambda x: x[0])
        meeting2map = load_scp_as_dict(args.map_scp)
        meeting2labels = OrderedDict()
        for utt_id, chunk_label in utt2labels:
            mid = utt_id.split("-")[0]
            if mid not in meeting2labels:
                meeting2labels[mid] = []
            meeting2labels[mid].append(chunk_label)
        task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]

        return task_list, None, args

    def post(self, result_list, args):
        # Concatenate the RTTM lines produced by every worker into one file.
        with open(args.out_rttm, "wt") as fd:
            for results in result_list:
                fd.writelines(results)
+
+
def int2vec(x, vec_dim=8, dtype=int):
    """Decode the integer label *x* into a 0/1 speaker-activity vector.

    Bit i of x (little-endian: lowest bit first) marks whether speaker i
    is active, e.g. ``int2vec(5, 4) -> [1, 0, 1, 0]``.

    Args:
        x: non-negative integer label (must fit in vec_dim bits).
        vec_dim: number of speaker slots / output length.
        dtype: numpy dtype of the returned vector.

    Returns:
        np.ndarray of shape (vec_dim,) with 0/1 entries.
    """
    # BUGFIX: the default was ``dtype=np.int``; the np.int alias was removed
    # in NumPy 1.24, which made this module fail at import time. The builtin
    # ``int`` is the documented drop-in replacement with identical behavior.
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(b)[::-1]) == '1').astype(dtype)
+
+
def seq2arr(seq, vec_dim=8):
    """Convert a sequence of integer-like labels into a (len(seq), vec_dim)
    0/1 matrix, one little-endian bit vector per element."""
    rows = []
    for item in seq:
        bits = ('{:0' + str(vec_dim) + 'b}').format(int(item))
        # Lowest bit first, matching int2vec().
        rows.append((np.array(list(bits)[::-1]) == '1').astype(int))
    return np.vstack(rows)
+
+
def sample2ms(sample, sr=16000):
    """Convert a sample index at rate *sr* into a 10 ms frame index
    (100 frames per second), truncating toward zero."""
    frames_per_second = 100
    return int(float(sample) / sr * frames_per_second)
+
+
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
    """Merge overlapping chunk-level label sequences into frame-level
    multi-speaker 0/1 labels by majority voting.

    Args:
        chunk_label_list: per-chunk lists of integer-string labels, each
            encoding the active-speaker set as a little-endian bit field.
        chunk_len: nominal chunk length in frames.
        shift_len: hop between consecutive chunks in frames.
        n_spk: number of speaker slots (bit-vector width).
        vote_prob: a speaker counts as active in a frame when its averaged
            vote across overlapping chunks exceeds this threshold.

    Returns:
        np.ndarray of shape (n_frame, n_spk) with 0/1 entries.
    """
    n_chunk = len(chunk_label_list)
    # The last chunk overlaps the previous one by chunk_len - shift_len
    # frames; only the remainder contributes new frames to the total.
    last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
    # NOTE(review): assumes n_chunk >= 2 -- a single-chunk meeting would make
    # n_frame too small/negative; confirm upstream guarantees this.
    n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
    multi_labels = np.zeros((n_frame, n_spk), dtype=float)
    weight = np.zeros((n_frame, 1), dtype=float)
    for i in range(n_chunk):
        raw_label = chunk_label_list[i]
        # Replace '<unk>' by the previous label (or silence for the first frame).
        for k in range(len(raw_label)):
            if raw_label[k] == '<unk>':
                raw_label[k] = raw_label[k-1] if k > 0 else '0'
        chunk_multi_label = seq2arr(raw_label, n_spk)
        # Deliberately rebinds chunk_len to this chunk's actual frame count;
        # subsequent iterations and slices use the updated value.
        chunk_len = chunk_multi_label.shape[0]
        multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
        weight[i*shift_len:i*shift_len+chunk_len, :] += 1
    multi_labels = multi_labels / weight  # normalizing vote
    multi_labels = (multi_labels > vote_prob).astype(int)  # voting results
    return multi_labels
+
+
def calc_spk_turns(label_arr, spk_list):
    """Extract speaker turns from a frame-level 0/1 activity matrix.

    Args:
        label_arr: array of shape (n_frames, n_spk); 1 marks activity.
        spk_list: speaker name per column; columns named "None" are skipped.

    Returns:
        list of [speaker_name, start_frame, duration_frames] entries.
    """
    turns = []
    n_frames, n_spk = label_arr.shape
    for col in range(n_spk):
        if spk_list[col] == "None":
            continue
        seg_start = None  # None while outside a speech segment
        for t in range(n_frames):
            if label_arr[t, col] == 1 and seg_start is None:
                seg_start = t
            if label_arr[t, col] == 0 and seg_start is not None:
                turns.append([spk_list[col], seg_start, t - seg_start])
                seg_start = None
        # Close a segment that runs to the end of the recording.
        if seg_start is not None:
            turns.append([spk_list[col], seg_start, n_frames - seg_start])
    return turns
+
+
def smooth_multi_labels(multi_label, win_len):
    """Median-smooth each speaker track along the time axis to remove
    spurious short flips; frames outside the array count as silence."""
    smoothed = median_filter(multi_label, (win_len, 1), mode="constant", cval=0.0)
    return smoothed.astype(int)
+
+
def process(task_args):
    """Worker: convert one batch of meetings to RTTM lines.

    task_args is (worker_idx, task_list, shared, args) as dispatched by
    MultiProcessRunnerV3; each task is
    (meeting_id, chunk_label_list, segment_map_path).

    Returns:
        list of formatted RTTM lines for all meetings in task_list.
    """
    _, task_list, _, args = task_args
    spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
    # RTTM: SPEAKER <file> <chan> <onset_s> <dur_s> <NA> <NA> <spk> <NA> <NA>
    template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
    results = []
    # NOTE(review): args.sr and args.no_pbar are not declared in
    # MyRunner.prepare -- presumably supplied by the base runner's parser;
    # confirm before changing.
    for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        # Map rows: seg_id -> [org_start, org_end, start, end] in samples,
        # relating the silence-removed audio back to the original timeline.
        utt2map = load_scp_as_list(map_file_path, 'list')
        multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
        multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
        # Length of the original recording in 10 ms frames (end of last segment).
        org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
        org_multi_labels = np.zeros((org_len, args.n_spk))
        for seg_id, [org_st, org_ed, st, ed] in utt2map:
            org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
            st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
            # Guard against off-by-one length mismatches after rounding.
            ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
            org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
        spk_turns = calc_spk_turns(org_multi_labels, spk_list)
        spk_turns = sorted(spk_turns, key=lambda x: x[1])
        for spk, st, dur in spk_turns:
            # TODO: handle the leak of segments at the change points
            if dur > args.ignore_len:
                # Frame indices / 100 -> seconds for the RTTM fields.
                results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
    return results
+
+
if __name__ == '__main__':
    # Dispatch per-meeting tasks across worker processes and write the
    # merged RTTM file in MyRunner.post().
    my_runner = MyRunner(process)
    my_runner.run()
diff --git a/egs/alimeeting/diarization/sond/path.sh b/egs/alimeeting/diarization/sond/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/path.sh
@@ -0,0 +1,5 @@
# Environment setup for the AliMeeting SOND diarization recipe.
# Resolve the FunASR repository root relative to this recipe directory.
export FUNASR_DIR=$PWD/../../..

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
# Expose the FunASR command-line tools on PATH.
export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/alimeeting/diarization/sond/run.sh b/egs/alimeeting/diarization/sond/run.sh
new file mode 100644
index 0000000..7e9a7f7
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/run.sh
@@ -0,0 +1,48 @@
#!/bin/bash
# AliMeeting SOND diarization recipe.
#   stage 0: download the test data, the pre-trained models and dscore
#   stage 1: run SOND inference and convert chunk labels to an RTTM file
#   stage 2: score the predicted RTTM against the reference with dscore

. ./path.sh || exit 1;

stage=0
stop_stage=2

. utils/parse_options.sh || exit 1;

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Downloading AliMeeting test set data..."
  wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/alimeeting_test_data_for_sond.tar.gz
  echo "Done. Extracting data..."
  tar zxf alimeeting_test_data_for_sond.tar.gz
  echo "Done."

  echo "Downloading Pre-trained model..."
  # Speaker-verification model (x-vector extractor) and the SOND model.
  git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
  git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
  ln -s speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pth
  cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
  ln -s speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pth
  cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
  cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
  echo "Done."

  echo "Downloading dscore for scoring..."
  git clone https://github.com/nryant/dscore.git
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Calculating diarization results..."
  python infer_alimeeting_test.py sond_fbank.yaml sond.pth outputs
  # Post-process: median-smooth over 83 frames, majority vote at 0.5,
  # drop segments no longer than 10 frames (0.1 s), 16 speaker slots.
  python local/convert_label_to_rttm.py \
    outputs/labels.txt \
    data/test_rmsil/raw_rmsil_map.scp \
    outputs/prediction_sm_83.rttm \
    --ignore_len 10 --no_pbar --smooth_size 83 \
    --vote_prob 0.5 --n_spk 16
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Scoring..."
  # DER with a 0.25 s collar around reference boundaries.
  python dscore/score.py \
    -r data/test_rmsil/test_org.crttm \
    -s outputs/prediction_sm_83.rttm \
    --collar 0.25
fi
diff --git a/egs/alimeeting/diarization/sond/unit_test.py b/egs/alimeeting/diarization/sond/unit_test.py
new file mode 100644
index 0000000..84a4247
--- /dev/null
+++ b/egs/alimeeting/diarization/sond/unit_test.py
@@ -0,0 +1,97 @@
+from funasr.bin.diar_inference_launch import inference_launch
+import os
+
+
def test_fbank_cpu_infer():
    """Smoke-test SOND inference on CPU with pre-extracted fbank features."""
    inputs = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config="config_fbank.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        num_workers=1,
        log_level="WARNING",
    )
    print(pipeline(inputs))
+
+
def test_fbank_gpu_infer():
    """Smoke-test SOND inference on one GPU with pre-extracted fbank features."""
    inputs = [
        ("data/unit_test/test_feats.scp", "speech", "kaldi_ark"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config="config_fbank.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    print(pipeline(inputs))
+
+
def test_wav_gpu_infer():
    """Smoke-test SOND inference on one GPU reading raw wav input."""
    inputs = [
        ("data/unit_test/test_wav.scp", "speech", "sound"),
        ("data/unit_test/test_profile.scp", "profile", "kaldi_ark"),
    ]
    pipeline = inference_launch(
        mode="sond",
        diar_train_config="config.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
    )
    print(pipeline(inputs))
+
+
def test_without_profile_gpu_infer():
    """Smoke-test the demo mode that builds speaker profiles on the fly
    from enrollment wavs instead of precomputed x-vectors."""
    # First entry is the meeting recording, the rest are per-speaker wavs.
    record_and_profiles = [[
        "data/unit_test/raw_inputs/record.wav",
        "data/unit_test/raw_inputs/spk1.wav",
        "data/unit_test/raw_inputs/spk2.wav",
        "data/unit_test/raw_inputs/spk3.wav",
        "data/unit_test/raw_inputs/spk4.wav"
    ]]
    pipeline = inference_launch(
        mode="sond_demo",
        diar_train_config="config.yaml",
        diar_model_file="sond.pth",
        output_dir="./outputs",
        ngpu=1,
        num_workers=1,
        log_level="WARNING",
        param_dict={},
    )
    print(pipeline(raw_inputs=record_and_profiles))
+
+
if __name__ == '__main__':
    # Pin all tests to the first visible GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    test_fbank_cpu_infer()
    test_fbank_gpu_infer()
    test_wav_gpu_infer()
    test_without_profile_gpu_infer()
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
new file mode 100644
index 0000000..c2e4354
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained Data2Vec Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - <strong>max_epoch:</strong> # number of training epoch
+ - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+ python finetune.py
+```
+
+### Inference
+
+Or you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>output_dir:</strong> # result dir
+ - <strong>ngpu:</strong> # the number of GPUs for decoding
+ - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+ - <strong>output_dir:</strong> # result dir
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to finetune with:
+```python
+ python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
new file mode 100644
index 0000000..a5f1ee4
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
@@ -0,0 +1,37 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
def modelscope_finetune(params):
    """Finetune the pretrained model on a local dataset via the
    ModelScope ASR trainer."""
    os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    trainer_args = {
        "model": params.model,
        "data_dir": ds_dict,
        "dataset_type": params.dataset_type,
        "work_dir": params.output_dir,
        "batch_bins": params.batch_bins,
        "max_epoch": params.max_epoch,
        "lr": params.lr,
    }
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=trainer_args)
    trainer.train()
+
+
if __name__ == '__main__':
    # Default hyper-parameters; adjust before launching finetuning.
    params = modelscope_args(model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
                             data_path="./data")
    params.output_dir = "./checkpoint"    # result dir
    params.data_path = "./example_data/"  # must contain train/ and validation/
    params.dataset_type = "small"         # use "large" for >1000 h corpora
    params.batch_bins = 16000             # feature frames per batch (small type)
    params.max_epoch = 50
    params.lr = 0.00002

    modelscope_finetune(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
new file mode 100644
index 0000000..c016c19
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+from multiprocessing import Pool
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
def modelscope_infer_core(output_dir, split_dir, njob, idx):
    """Decode one shard (wav.{idx}.scp) on the GPU assigned to job idx.

    Jobs are numbered from 1 and njob jobs share each GPU.
    """
    job_output_dir = os.path.join(output_dir, "output.{}".format(idx))
    gpu_id = (int(idx) - 1) // njob
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # Respect a pre-set device list by indexing into it.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(visible.split(",")[gpu_id])
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    inference_pipline = pipeline(
        task=Tasks.auto_speech_recognition,
        model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
        output_dir=job_output_dir,
    )
    inference_pipline(audio_in=os.path.join(split_dir, "wav.{}.scp".format(idx)))
+
+
def modelscope_infer(params):
    """Run multi-GPU, multi-job decoding over a wav.scp and merge results.

    Splits params["data_dir"]/wav.scp into ngpu*njob shards, decodes each
    shard in its own process via modelscope_infer_core, concatenates the
    per-job outputs under <output_dir>/1best_recog, and computes CER when a
    ground-truth text file is present.

    Args:
        params: dict with keys "ngpu", "njob", "output_dir", "data_dir".
    """
    # prepare for multi-GPU decoding
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)
    nj = ngpu * njob

    # Split wav.scp into nj shards; the last shard absorbs the remainder.
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
    num_job_lines = len(lines) // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        shard_file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
        with open(shard_file, "w") as f:
            f.writelines(lines[start:] if i == nj - 1 else lines[start:end])
        start = end

    # Fan out decoding jobs. BUGFIX: keep the AsyncResult handles and call
    # get() after join() -- apply_async otherwise swallows worker exceptions
    # silently and the merge step below would read missing/partial outputs.
    p = Pool(nj)
    async_results = [
        p.apply_async(modelscope_infer_core,
                      args=(output_dir, split_dir, njob, str(i + 1)))
        for i in range(nj)
    ]
    p.close()
    p.join()
    for r in async_results:
        r.get()  # re-raises any exception from the worker

    # combine decoding results
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    files = ["text", "token", "score"]
    for file in files:
        with open(os.path.join(best_recog_path, file), "w") as f:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
                with open(job_file) as f_job:
                    f.writelines(f_job.readlines())

    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
+
+
+if __name__ == "__main__":
+ params = {}
+ params["data_dir"] = "./data/test"
+ params["output_dir"] = "./results"
+ params["ngpu"] = 2
+ params["njob"] = 5
+ modelscope_infer(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
new file mode 100644
index 0000000..56c282c
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
@@ -0,0 +1,52 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
def modelscope_infer_after_finetune(params):
    """Decode a test set with a locally finetuned model and compute CER.

    Copies the required auxiliary files from the ModelScope cache into
    params["output_dir"] (rewriting configuration.json so decoding uses the
    finetuned checkpoint), runs the ASR pipeline on
    params["data_dir"]/wav.scp, and, if a ground-truth text file exists,
    writes the CER report to <output_dir>/decode_results/text.cer.

    Args:
        params: dict with keys "modelscope_model_name", "required_files",
            "output_dir", "data_dir", "decoding_model_name".
    """
    # prepare for decoding
    # BUGFIX: os.path.expanduser is portable, unlike os.environ["HOME"],
    # which raises KeyError where HOME is unset (e.g. on Windows).
    pretrained_model_path = os.path.join(
        os.path.expanduser("~"), ".cache/modelscope/hub", params["modelscope_model_name"])
    for file_name in params["required_files"]:
        if file_name == "configuration.json":
            # Patch the config so decoding picks up the finetuned checkpoint.
            with open(os.path.join(pretrained_model_path, file_name)) as f:
                config_dict = json.load(f)
            config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(os.path.join(pretrained_model_path, file_name),
                        os.path.join(params["output_dir"], file_name))
    # Start from an empty decoding directory.
    decoding_path = os.path.join(params["output_dir"], "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)

    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=params["output_dir"],
        output_dir=decoding_path,
    )
    inference_pipeline(audio_in=os.path.join(params["data_dir"], "wav.scp"))

    # compute CER if ground-truth text is available
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
+
+
if __name__ == '__main__':
    # Decoding defaults; decoding_model_name selects the finetuned checkpoint.
    params = {}
    params["modelscope_model_name"] = "damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k"
    params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
    params["output_dir"] = "./checkpoint"   # dir holding the finetuned model
    params["data_dir"] = "./data/test"      # must contain wav.scp (and optionally text)
    params["decoding_model_name"] = "valid.cer_ctc.ave.pth"
    modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..5eeae37
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,23 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset CER(%) | base model|finetune model |
+|:--------------:|:---------:|:-------------:|
+| dev | 1.75 |1.62 |
+| test | 1.95 |1.78 |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..71d9fee
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,25 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-2
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | base model|finetune model|
+|:------------:|:---------:|:------------:|
+| dev_ios | 2.80 |2.60 |
+| test_android | 3.13 |2.84 |
+| test_ios | 2.85 |2.82 |
+| test_mic | 3.06 |2.88 |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..ec95be3
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,75 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Tue Nov 22 18:48:39 CST 2022`
+- python version: `3.7.12`
+- FunASR version: `0.1.0`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:---------:|:-----:|
+| dev | 1.75 |
+| test | 1.95 |
+
+## AISHELL-2
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:------------:|:-----:|
+| dev_ios | 2.80 |
+| test_android | 3.13 |
+| test_ios | 2.85 |
+| test_mic | 3.06 |
+
+## Wenetspeech
+- Decode config:
+ - Decode without CTC
+ - Decode without LM
+
+| testset | CER(%)|
+|:---------:|:-----:|
+| dev | 3.57 |
+| test | 6.97 |
+| test_net | 6.74 |
+
+## SpeechIO TIOBE
+- Decode config 1:
+ - Decode without CTC
+ - Decode without LM
+ - With text norm
+- Decode config 2:
+ - Decode without CTC
+ - Decode with Transformer-LM
+ - LM weight: 0.15
+ - With text norm
+
+| testset | w/o LM | w/ LM |
+|:------------------:|:----:|:----:|
+|SPEECHIO_ASR_ZH00001| 0.49 | 0.35 |
+|SPEECHIO_ASR_ZH00002| 3.23 | 2.86 |
+|SPEECHIO_ASR_ZH00003| 1.13 | 0.80 |
+|SPEECHIO_ASR_ZH00004| 1.33 | 1.10 |
+|SPEECHIO_ASR_ZH00005| 1.41 | 1.18 |
+|SPEECHIO_ASR_ZH00006| 5.25 | 4.85 |
+|SPEECHIO_ASR_ZH00007| 5.51 | 4.97 |
+|SPEECHIO_ASR_ZH00008| 3.69 | 3.18 |
+|SPEECHIO_ASR_ZH00009| 3.02 | 2.78 |
+|SPEECHIO_ASR_ZH00010| 3.35 | 2.99 |
+|SPEECHIO_ASR_ZH00011| 1.54 | 1.25 |
+|SPEECHIO_ASR_ZH00012| 2.06 | 1.68 |
+|SPEECHIO_ASR_ZH00013| 2.57 | 2.25 |
+|SPEECHIO_ASR_ZH00014| 3.86 | 3.08 |
+|SPEECHIO_ASR_ZH00015| 3.34 | 2.67 |
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md
new file mode 100644
index 0000000..dfd509d
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained UniASR Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - <strong>max_epoch:</strong> # number of training epoch
+ - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+ python finetune.py
+```
+
+### Inference
+
+Or you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>output_dir:</strong> # result dir
+ - <strong>ngpu:</strong> # the number of GPUs for decoding
+ - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+ - <strong>output_dir:</strong> # result dir
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/decode_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py
index 1aef9c6..2ecc229 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py
@@ -1,35 +1,36 @@
import os
+
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
+
from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
+ ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
+ model=params.model,
data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline"
- params["model_revision"] = None
+ params = modelscope_args(model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline", data_path="./data")
+    params.output_dir = "./checkpoint"                      # 模型保存路径
+    params.data_path = "./example_data/"                    # 数据路径
+    params.dataset_type = "small"                           # 小数据量设置small，若数据量大于1000小时，请使用large
+    params.batch_bins = 2000                                # batch size，如果dataset_type="small"，batch_bins单位为fbank特征帧数，如果dataset_type="large"，batch_bins单位为毫秒，
+    params.max_epoch = 20                                   # 最大训练轮数
+    params.lr = 0.00005                                     # 设置学习率
+
modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
index 85ddeee..3a89546 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
@@ -1,13 +1,89 @@
+import os
+import shutil
+from multiprocessing import Pool
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_fa.wav"
- output_dir = "./results"
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+ output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+ gpu_id = (int(idx) - 1) // njob
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline",
- output_dir=output_dir,
+ output_dir=output_dir_job,
+ batch_size=1
)
- rec_result = inference_pipline(audio_in=audio_in)
- print(rec_result)
+ audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+ inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+ # prepare for multi-GPU decoding
+ ngpu = params["ngpu"]
+ njob = params["njob"]
+ output_dir = params["output_dir"]
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ os.mkdir(output_dir)
+ split_dir = os.path.join(output_dir, "split")
+ os.mkdir(split_dir)
+ nj = ngpu * njob
+ wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(modelscope_infer_core,
+ args=(output_dir, split_dir, njob, str(i + 1)))
+ p.close()
+ p.join()
+
+ # combine decoding results
+ best_recog_path = os.path.join(output_dir, "1best_recog")
+ os.mkdir(best_recog_path)
+ files = ["text", "token", "score"]
+ for file in files:
+ with open(os.path.join(best_recog_path, file), "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+ with open(job_file) as f_job:
+ lines = f_job.readlines()
+ f.writelines(lines)
+
+ # If text exists, compute CER
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(best_recog_path, "token")
+ compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
+ os.system("tail -n 3 {}".format(os.path.join(best_recog_path, "text.cer")))
+
+
+if __name__ == "__main__":
+ params = {}
+ params["data_dir"] = "./data/test"
+ params["output_dir"] = "./results"
+ params["ngpu"] = 1
+ params["njob"] = 8
+ modelscope_infer(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py
new file mode 100644
index 0000000..d91a40a
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py
@@ -0,0 +1,54 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+ # prepare for decoding
+ pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+ for file_name in params["required_files"]:
+ if file_name == "configuration.json":
+ with open(os.path.join(pretrained_model_path, file_name)) as f:
+ config_dict = json.load(f)
+ config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+ with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+ json.dump(config_dict, f, indent=4, separators=(',', ': '))
+ else:
+ shutil.copy(os.path.join(pretrained_model_path, file_name),
+ os.path.join(params["output_dir"], file_name))
+ decoding_path = os.path.join(params["output_dir"], "decode_results")
+ if os.path.exists(decoding_path):
+ shutil.rmtree(decoding_path)
+ os.mkdir(decoding_path)
+
+ # decoding
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=params["output_dir"],
+ output_dir=decoding_path,
+ batch_size=1
+ )
+ audio_in = os.path.join(params["data_dir"], "wav.scp")
+ inference_pipeline(audio_in=audio_in)
+
+    # compute CER if GT text is set
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
+ os.system("tail -n 3 {}".format(os.path.join(decoding_path, "text.cer")))
+
+
+if __name__ == '__main__':
+ params = {}
+ params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline"
+ params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
+ params["output_dir"] = "./checkpoint"
+ params["data_dir"] = "./data/test"
+ params["decoding_model_name"] = "20epoch.pth"
+ modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/README.md
new file mode 100644
index 0000000..dfd509d
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained UniASR Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+ - <strong>output_dir:</strong> # result dir
+ - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+ - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+ - <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
+ - <strong>max_epoch:</strong> # number of training epoch
+ - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+ python finetune.py
+```
+
+### Inference
+
+Or you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>output_dir:</strong> # result dir
+ - <strong>ngpu:</strong> # the number of GPUs for decoding
+ - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+ - <strong>output_dir:</strong> # result dir
+  - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+ - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/decode_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/finetune.py
index 3bdf1cc..2469e53 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/finetune.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/finetune.py
@@ -1,35 +1,36 @@
import os
+
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
+
from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
+ if not os.path.exists(params.output_dir):
+ os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
+ ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
+ model=params.model,
data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
+ dataset_type=params.dataset_type,
+ work_dir=params.output_dir,
+ batch_bins=params.batch_bins,
+ max_epoch=params.max_epoch,
+ lr=params.lr)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online"
- params["model_revision"] = None
+ params = modelscope_args(model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online", data_path="./data")
+    params.output_dir = "./checkpoint"                      # 模型保存路径
+    params.data_path = "./example_data/"                    # 数据路径
+    params.dataset_type = "small"                           # 小数据量设置small，若数据量大于1000小时，请使用large
+    params.batch_bins = 2000                                # batch size，如果dataset_type="small"，batch_bins单位为fbank特征帧数，如果dataset_type="large"，batch_bins单位为毫秒，
+    params.max_epoch = 20                                   # 最大训练轮数
+    params.lr = 0.00005                                     # 设置学习率
+
modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
index 960c393..ecb1381 100644
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer.py
@@ -1,13 +1,89 @@
+import os
+import shutil
+from multiprocessing import Pool
+
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_fa.wav"
- output_dir = "./results"
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+ output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+ gpu_id = (int(idx) - 1) // njob
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+ else:
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online",
- output_dir=output_dir,
+ output_dir=output_dir_job,
+ batch_size=1
)
- rec_result = inference_pipline(audio_in=audio_in)
- print(rec_result)
+ audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+ inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+ # prepare for multi-GPU decoding
+ ngpu = params["ngpu"]
+ njob = params["njob"]
+ output_dir = params["output_dir"]
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ os.mkdir(output_dir)
+ split_dir = os.path.join(output_dir, "split")
+ os.mkdir(split_dir)
+ nj = ngpu * njob
+ wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(modelscope_infer_core,
+ args=(output_dir, split_dir, njob, str(i + 1)))
+ p.close()
+ p.join()
+
+ # combine decoding results
+ best_recog_path = os.path.join(output_dir, "1best_recog")
+ os.mkdir(best_recog_path)
+ files = ["text", "token", "score"]
+ for file in files:
+ with open(os.path.join(best_recog_path, file), "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+ with open(job_file) as f_job:
+ lines = f_job.readlines()
+ f.writelines(lines)
+
+ # If text exists, compute CER
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(best_recog_path, "token")
+ compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
+ os.system("tail -n 3 {}".format(os.path.join(best_recog_path, "text.cer")))
+
+
+if __name__ == "__main__":
+ params = {}
+ params["data_dir"] = "./data/test"
+ params["output_dir"] = "./results"
+ params["ngpu"] = 1
+ params["njob"] = 8
+ modelscope_infer(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer_after_finetune.py
new file mode 100644
index 0000000..f9fb0db
--- /dev/null
+++ b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/infer_after_finetune.py
@@ -0,0 +1,54 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+ # prepare for decoding
+ pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+ for file_name in params["required_files"]:
+ if file_name == "configuration.json":
+ with open(os.path.join(pretrained_model_path, file_name)) as f:
+ config_dict = json.load(f)
+ config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+ with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+ json.dump(config_dict, f, indent=4, separators=(',', ': '))
+ else:
+ shutil.copy(os.path.join(pretrained_model_path, file_name),
+ os.path.join(params["output_dir"], file_name))
+ decoding_path = os.path.join(params["output_dir"], "decode_results")
+ if os.path.exists(decoding_path):
+ shutil.rmtree(decoding_path)
+ os.mkdir(decoding_path)
+
+ # decoding
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=params["output_dir"],
+ output_dir=decoding_path,
+ batch_size=1
+ )
+ audio_in = os.path.join(params["data_dir"], "wav.scp")
+ inference_pipeline(audio_in=audio_in)
+
+    # compute CER if GT text is set
+ text_in = os.path.join(params["data_dir"], "text")
+ if os.path.exists(text_in):
+ text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+ compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
+ os.system("tail -n 3 {}".format(os.path.join(decoding_path, "text.cer")))
+
+
+if __name__ == '__main__':
+ params = {}
+ params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online"
+ params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
+ params["output_dir"] = "./checkpoint"
+ params["data_dir"] = "./data/test"
+ params["decoding_model_name"] = "20epoch.pth"
+ modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
new file mode 100644
index 0000000..ed3b7e2
--- /dev/null
+++ b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
@@ -0,0 +1,17 @@
+
+
+##################text二进制数据#####################
+inputs = "hello 大 家 好 呀"
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipline = pipeline(
+ task=Tasks.language_model,
+ model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
+ output_dir="./tmp/"
+)
+
+rec_result = inference_pipline(text_in=inputs)
+print(rec_result)
+
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 16fa3e5..ca8f2bc 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -453,7 +453,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 3769b6c..6c5acfc 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -3,6 +3,9 @@
import logging
import sys
import time
+import copy
+import os
+import codecs
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -35,6 +38,8 @@
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -78,6 +83,7 @@
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
**kwargs,
):
assert check_argument_types()
@@ -168,6 +174,34 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from file or str
+ if hotword_list_or_file is None:
+ self.hotword_list = None
+ elif os.path.exists(hotword_list_or_file):
+ self.hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ else:
+ logging.info("Attempting to parse hotwords as str...")
+ self.hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([1])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+
+
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
beam_search = None
@@ -229,8 +263,14 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if not isinstance(self.asr_model, ContextualParaformer):
+ if self.hotword_list:
+ logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ else:
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
@@ -388,6 +428,11 @@
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
@@ -416,6 +461,7 @@
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
@@ -497,7 +543,7 @@
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
@@ -551,7 +597,12 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+ parser.add_argument(
+ "--hotword",
+ type=str_or_none,
+ default=None,
+ help="hotword file path or hotwords seperated by space"
+ )
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -679,8 +730,10 @@
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
+ param_dict = {'hotword': args.hotword}
kwargs = vars(args)
kwargs.pop("config", None)
+ kwargs['param_dict'] = param_dict
inference(**kwargs)
diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index 7e2e414..7da48e2 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -436,7 +436,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 2832504..dbb2719 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -241,6 +241,11 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
finish_count = 0
file_count = 1
@@ -284,8 +289,10 @@
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
-
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
@@ -293,9 +300,11 @@
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
- text_postprocessed_punc = text_postprocessed
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+ text_postprocessed_punc = text_postprocessed
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1d09c79..c4bb61b 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -14,6 +14,7 @@
from typing import Any
from typing import List
import math
+import copy
import numpy as np
import torch
from typeguard import check_argument_types
@@ -38,8 +39,9 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
+from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
+from funasr.models.e2e_asr_paraformer import BiCifParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -234,6 +236,10 @@
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ if isinstance(self.asr_model, BiCifParaformer):
+ _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ pre_token_length) # test no bias cif2
+
results = []
b, n, d = decoder_out.size()
for i in range(b):
@@ -276,9 +282,12 @@
else:
text = None
- time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
-
- results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+ if isinstance(self.asr_model, BiCifParaformer):
+ timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+ else:
+ time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
+ results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
@@ -561,6 +570,11 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
finish_count = 0
file_count = 1
@@ -603,8 +617,11 @@
result = result_segments[0]
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
-
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
@@ -612,9 +629,12 @@
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
- text_postprocessed_punc = text_postprocessed
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+ text_postprocessed_punc = text_postprocessed
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -492,7 +492,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -492,7 +492,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
new file mode 100755
index 0000000..c3e210b
--- /dev/null
+++ b/funasr/bin/diar_inference_launch.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Speaker Diarization",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=False)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--njob",
+ type=int,
+ default=1,
+ help="The number of jobs for each gpu",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=False,
+ action="append",
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=True)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--vad_infer_config",
+ type=str,
+ help="VAD infer configuration",
+ )
+ group.add_argument(
+ "--vad_model_file",
+ type=str,
+ help="VAD model parameter file",
+ )
+ group.add_argument(
+ "--diar_train_config",
+ type=str,
+ help="Diarization training configuration",
+ )
+ group.add_argument(
+ "--diar_model_file",
+ type=str,
+ help="Diarization model parameter file",
+ )
+ group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global CMVN file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("The inference configuration related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument(
+ "--diar_smooth_size",
+ type=int,
+ default=121,
+ help="The smoothing size for post-processing"
+ )
+
+ return parser
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sond":
+ from funasr.bin.sond_inference import inference_modelscope
+ return inference_modelscope(**kwargs)
+ elif mode == "sond_demo":
+ from funasr.bin.sond_inference import inference_modelscope
+ param_dict = {
+ "extract_profile": True,
+ "sv_train_config": "sv.yaml",
+ "sv_model_file": "sv.pth",
+ }
+ if "param_dict" in kwargs:
+ kwargs["param_dict"].update(param_dict)
+ else:
+ kwargs["param_dict"] = param_dict
+ return inference_modelscope(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="sond",
+ help="The decoding mode",
+ )
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
index 27a8a71..198d578 100755
--- a/funasr/bin/lm_calc_perplexity.py
+++ b/funasr/bin/lm_calc_perplexity.py
@@ -56,7 +56,7 @@
set_all_random_seed(seed)
# 2. Build LM
- model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
+ model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
# Wrape model to make model.nll() data-parallel
wrapped_model = ForwardAdaptor(model, "nll")
wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@
utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
# Write PPL of each utts for debugging or analysis
+ writer["utt2nll"][key] = str(-_nll)
writer["utt2ppl"][key] = str(utt_ppl)
writer["utt2ntokens"][key] = str(ntoken)
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
new file mode 100644
index 0000000..909cb02
--- /dev/null
+++ b/funasr/bin/lm_inference.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+def inference(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float],
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ **kwargs,
+):
+ inference_pipeline = inference_modelscope(
+ output_dir=output_dir,
+ raw_inputs=raw_inputs,
+ batch_size=batch_size,
+ dtype=dtype,
+ ngpu=ngpu,
+ seed=seed,
+ num_workers=num_workers,
+ log_level=log_level,
+ key_file=key_file,
+ train_config=train_config,
+ model_file=model_file,
+ log_base = log_base,
+ allow_variable_data_keys = allow_variable_data_keys,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ **kwargs,
+ )
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build Model
+ model, train_args = LMTask.build_model_from_file(
+ train_config, model_file, device)
+ wrapped_model = ForwardAdaptor(model, "nll")
+ wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ logging.info(f"Model:\n{model}")
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file
+ )
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ if output_dir_v2 is not None:
+ writer = DatadirWriter(output_dir_v2)
+ else:
+ writer = None
+
+ if raw_inputs != None:
+ line = raw_inputs.strip()
+ key = "lm demo"
+ if line=="":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ batch = {}
+ batch['text'] = line
+ if preprocessor != None:
+ batch = preprocessor(key, batch)
+
+ # Force data-precision
+ for name in batch:
+ value = batch[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype("float32")
+ elif value.dtype.kind == "i":
+ value = value.astype("long")
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ batch[name] = value
+
+ batch["text_lengths"] = torch.from_numpy(
+ np.array([len(batch["text"])], dtype='int32'))
+ batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## compute ppl
+ ppl_out_batch = ""
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for sent_ids, sent_nll in zip(batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ results.append(item)
+
+ return results
+
+ # 3. Build data-iterator
+ loader = LMTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=LMTask.build_collate_fn(train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4. Start for-loop
+ total_nll = 0.0
+ total_ntokens = 0
+ ppl_out_all = ""
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ ppl_out_batch = ""
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ # NOTE(kamo): data_parallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## print ppl
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ utt2nll = round(-sent_nll_sum.item(), 5)
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key+":\n"] = ppl_out
+ writer["utt2nll"][key] = str(utt2nll)
+ results.append(item)
+
+ ppl_out_all += ppl_out_batch
+
+ assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+ # nll: (B, L) -> (B,)
+ nll = nll.detach().cpu().numpy().sum(1)
+ # lengths: (B,)
+ lengths = lengths.detach().cpu().numpy()
+ total_nll += nll.sum()
+ total_ntokens += lengths.sum()
+
+ if log_base is None:
+ ppl = np.exp(total_nll / total_ntokens)
+ else:
+ ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+ avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+ total_nll=round(-total_nll.item(), 4),
+ total_ppl=round(ppl.item(), 4)
+ )
+ item = {'key': 'AVG PPL', 'value': avg_ppl}
+ ppl_out_all += avg_ppl
+ if writer is not None:
+ writer["ppl"]["AVG PPL : "] = avg_ppl
+ results.append(item)
+
+ return results
+
+ return _forward
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Calc perplexity",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=False)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ inference(**kwargs)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
new file mode 100644
index 0000000..492ebab
--- /dev/null
+++ b/funasr/bin/lm_inference_launch.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.types import float_or_none
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="Calc perplexity",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument("--gpuid_list", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--njob", type=int, default=1, help="The number of jobs for each gpu")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ parser.add_argument(
+ "--log_base",
+ type=float_or_none,
+ default=10,
+ help="The base of logarithm for Perplexity. "
+ "If None, napier's constant is used.",
+ required=False
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ action="append",
+ required=False
+ )
+ group.add_argument(
+ "--raw_inputs",
+ type=str,
+ required=False
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group.add_argument("--split_with_space", type=str2bool, default=False)
+ group.add_argument("--seg_dict_file", type=str_or_none)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument("--train_config", type=str)
+ group.add_argument("--model_file", type=str)
+ group.add_argument("--mode", type=str, default="lm")
+ return parser
+
+def inference_launch(mode, **kwargs):
+ if mode == "transformer":
+ from funasr.bin.lm_inference import inference_modelscope
+ return inference_modelscope(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ kwargs.pop("gpuid_list", None)
+ kwargs.pop("njob", None)
+ results = inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index faa7a45..8641465 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,22 +1,46 @@
#!/usr/bin/env python3
+
+import os
+
from funasr.tasks.lm import LMTask
-def get_parser():
+# for LM Training
+def parse_args():
parser = LMTask.get_parser()
- return parser
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
-def main(cmd=None):
- """LM training.
-
- Example:
-
- % python lm_train.py asr --print_config --optim adadelta
- % python lm_train.py --config conf/train_asr.yaml
- """
- LMTask.main(cmd=cmd)
+def main(args=None, cmd=None):
+ # for LM Training
+ LMTask.main(args=args, cmd=cmd)
-if __name__ == "__main__":
- main()
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small" and args.ngpu != 0:
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
new file mode 100755
index 0000000..299de0d
--- /dev/null
+++ b/funasr/bin/sond_inference.py
@@ -0,0 +1,544 @@
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from collections import OrderedDict
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.diar import DiarTask
+from funasr.tasks.asr import ASRTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from scipy.ndimage import median_filter
+from funasr.utils.misc import statistic_model_parameters
+
class Speech2Diarization:
    """Speech2Diarization class

    Wraps a trained SOND diarization model and converts its frame-level
    predictions into per-speaker segment lists.

    Examples:
        >>> import soundfile
        >>> import numpy as np
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio, profile)
        {"spk1": [(int, int), ...], ...}

    """

    def __init__(
        self,
        diar_train_config: Union[Path, str] = None,
        diar_model_file: Union[Path, str] = None,
        device: str = "cpu",
        batch_size: int = 1,
        dtype: str = "float32",
        streaming: bool = False,
        smooth_size: int = 83,
        dur_threshold: float = 10,
    ):
        """Load the diarization model and keep post-processing settings.

        Args:
            diar_train_config: training config yaml of the diarization model
            diar_model_file: diarization model parameter file
            device: "cpu" or "cuda"
            batch_size: kept for interface compatibility (only 1 is supported)
            dtype: torch dtype name used for model and inputs
            streaming: kept for interface compatibility (unused here)
            smooth_size: median-filter window (frames) for label smoothing
            dur_threshold: segments not longer than this many frames are dropped
        """
        assert check_argument_types()

        # 1. Build Diarization model
        diar_model, diar_train_args = DiarTask.build_model_from_file(
            config_file=diar_train_config,
            model_file=diar_model_file,
            device=device
        )
        logging.info("diar_model: {}".format(diar_model))
        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
        logging.info("diar_train_args: {}".format(diar_train_args))
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.diar_model = diar_model
        self.diar_train_args = diar_train_args
        self.token_list = diar_train_args.token_list
        self.smooth_size = smooth_size
        self.dur_threshold = dur_threshold
        self.device = device
        self.dtype = dtype

    def smooth_multi_labels(self, multi_label):
        # Median filter along time only (window spans 1 speaker column) to
        # remove short spurious flips in the binary activity matrix.
        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
        return multi_label

    @staticmethod
    def calc_spk_turns(label_arr, spk_list):
        """Convert a (T, n_spk) 0/1 activity matrix into [spk, start, duration] turns."""
        turn_list = []
        length = label_arr.shape[0]
        n_spk = label_arr.shape[1]
        for k in range(n_spk):
            # "None" marks a padding speaker slot without a real profile
            if spk_list[k] == "None":
                continue
            in_utt = False
            start = 0
            for i in range(length):
                if label_arr[i, k] == 1 and in_utt is False:
                    start = i
                    in_utt = True
                if label_arr[i, k] == 0 and in_utt is True:
                    turn_list.append([spk_list[k], start, i - start])
                    in_utt = False
            if in_utt:
                # close a turn that is still open at the end of the recording
                turn_list.append([spk_list[k], start, length - start])
        return turn_list

    @staticmethod
    def seq2arr(seq, vec_dim=8):
        """Decode a sequence of power-set label ids into a (T, vec_dim) 0/1 matrix."""
        # BUGFIX: dtype default was np.int, which was removed in numpy>=1.24;
        # the builtin int is the documented replacement with identical behavior.
        def int2vec(x, vec_dim=8, dtype=int):
            b = ('{:0' + str(vec_dim) + 'b}').format(x)
            # little-endian order: lower bit first
            return (np.array(list(b)[::-1]) == '1').astype(dtype)

        return np.row_stack([int2vec(int(x), vec_dim) for x in seq])

    def post_processing(self, raw_logits: torch.Tensor, spk_num: int):
        """Turn raw model logits into {spk: [(start, end), ...]} plus raw frame labels."""
        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
        # upsampling outputs to match inputs
        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
        # F.upsample is deprecated; F.interpolate is its supported equivalent.
        logits_idx = F.interpolate(
            logits_idx.unsqueeze(1).float(),
            size=(ut, ),
            mode="nearest",
        ).squeeze(1).long()
        logits_idx = logits_idx[0].tolist()
        pse_labels = [self.token_list[x] for x in logits_idx]
        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
        multi_labels = self.smooth_multi_labels(multi_labels)
        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
        results = OrderedDict()
        for spk, st, dur in spk_turns:
            if spk not in results:
                results[spk] = []
            # drop segments shorter than the duration threshold (in frames)
            if dur > self.dur_threshold:
                results[spk].append((st, st + dur))

        # sort segments in start time ascending
        for spk in results:
            results[spk] = sorted(results[spk], key=lambda x: x[0])

        return results, pse_labels

    @torch.no_grad()
    def __call__(
        self,
        speech: Union[torch.Tensor, np.ndarray],
        profile: Union[torch.Tensor, np.ndarray],
    ):
        """Inference

        Args:
            speech: Input speech data
            profile: Speaker profiles
        Returns:
            diarization results for each speaker and the raw frame labels

        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
        batch = {"speech": speech, "speech_lengths": speech_lengths,
                 "profile": profile, "profile_lengths": profile_lengths}
        # a. To device
        batch = to_device(batch, device=self.device)

        logits = self.diar_model.prediction_forward(**batch)
        results, pse_labels = self.post_processing(logits, profile.shape[1])

        return results, pse_labels

    @staticmethod
    def from_pretrained(
        model_tag: Optional[str] = None,
        **kwargs: Optional[Any],
    ):
        """Build Speech2Diarization instance from the pretrained model.

        Args:
            model_tag (Optional[str]): Model tag of the pretrained models.
                Currently, the tags of espnet_model_zoo are supported.

        Returns:
            Speech2Diarization: Speech2Diarization instance.

        """
        if model_tag is not None:
            try:
                from espnet_model_zoo.downloader import ModelDownloader

            except ImportError:
                logging.error(
                    "`espnet_model_zoo` is not installed. "
                    "Please install via `pip install -U espnet_model_zoo`."
                )
                raise
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))

        return Speech2Diarization(**kwargs)
+
+
def inference_modelscope(
    diar_train_config: str,
    diar_model_file: str,
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    seed: int = 0,
    num_workers: int = 0,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    smooth_size: int = 83,
    dur_threshold: int = 10,
    out_format: str = "vad",
    param_dict: Optional[dict] = None,
    **kwargs,
):
    """Build a reusable diarization forward function (modelscope-pipeline style).

    Loads the SOND diarization model once (and, optionally, a speaker-embedding
    model for profile extraction from raw waveforms) and returns a closure
    ``_forward`` that can be called repeatedly.

    Args:
        diar_train_config: training config yaml of the diarization model.
        diar_model_file: diarization model parameter file.
        output_dir: if set, results are also written to ``<output_dir>/result.txt``
            and frame labels to ``<output_dir>/labels.txt``.
        batch_size: must be 1; batch decoding is not implemented.
        ngpu: must be <= 1; >= 1 selects CUDA when available.
        dtype: torch dtype name for model and inputs.
        seed: random seed set before decoding.
        num_workers: DataLoader workers for the file-list input path.
        key_file / allow_variable_data_keys: forwarded to the data iterator.
        model_tag: optional pretrained model tag passed to ``from_pretrained``.
        streaming: forwarded to the model wrappers.
        smooth_size: median-filter window (frames) for label smoothing.
        dur_threshold: minimum segment duration (frames) to keep.
        out_format: "vad" for per-speaker segment lists, anything else for RTTM.
        param_dict: when it sets "extract_profile", must also carry
            "sv_train_config" and "sv_model_file" for x-vector extraction.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs, output_dir_v2,
        param_dict)`` returning a list of ``{"key", "value"}`` dicts.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("param_dict: {}".format(param_dict))

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2a. Build speech2xvec [Optional]
    # NOTE(review): speech2xvector is only bound when extract_profile is set;
    # the raw_inputs path below references it unconditionally and would raise
    # NameError otherwise — confirm callers always enable extract_profile
    # before passing raw waveforms.
    if param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
        assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
        assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
        sv_train_config = param_dict["sv_train_config"]
        sv_model_file = param_dict["sv_model_file"]
        from funasr.bin.sv_inference import Speech2Xvector
        speech2xvector_kwargs = dict(
            sv_train_config=sv_train_config,
            sv_model_file=sv_model_file,
            device=device,
            dtype=dtype,
            streaming=streaming,
            embedding_node="resnet1_dense"
        )
        logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
        speech2xvector = Speech2Xvector.from_pretrained(
            model_tag=model_tag,
            **speech2xvector_kwargs,
        )
        speech2xvector.sv_model.eval()

    # 2b. Build speech2diar
    speech2diar_kwargs = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
        model_tag=model_tag,
        **speech2diar_kwargs,
    )
    speech2diar.diar_model.eval()

    def output_results_str(results: dict, uttid: str):
        # Render {spk: [(start, end), ...]} as "vad" lines or RTTM lines.
        rst = []
        # recording/meeting id: strip the trailing "-<suffix>" from the utt id
        mid = uttid.rsplit("-", 1)[0]
        for key in results:
            # frame indices -> seconds; presumably 100 frames per second — confirm
            results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
        if out_format == "vad":
            for spk, segs in results.items():
                rst.append("{} {}".format(spk, segs))
        else:
            # RTTM SPEAKER line: <mid> <chan> <onset> <dur-field> ... <spk>
            template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
            for spk, segs in results.items():
                rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])

        return "\n".join(rst)

    def _forward(
        data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
        raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str]]] = None,
        output_dir_v2: Optional[str] = None,
        param_dict: Optional[dict] = None,
    ):
        # Run diarization over either a data list (scp-style) or raw inputs,
        # where each raw example is [speech, profile_wav1, profile_wav2, ...].
        logging.info("param_dict: {}".format(param_dict))
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, (list, tuple)):
                assert all([len(example) >= 2 for example in raw_inputs]), \
                    "The length of test case in raw_inputs must larger than 1 (>=2)."

                def prepare_dataset():
                    # Generator yielding one (keys, batch) pair per raw example.
                    for idx, example in enumerate(raw_inputs):
                        # read waveform file
                        # NOTE(review): the isinstance checks test example[0],
                        # not x — assumes all entries of an example share the
                        # first entry's type; confirm mixed inputs never occur.
                        example = [soundfile.read(x)[0] if isinstance(example[0], str) else x
                                   for x in example]
                        # convert torch tensor to numpy array
                        example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
                                   for x in example]
                        speech = example[0]
                        logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
                        profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
                        profile = torch.cat(profile, dim=0)
                        yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}

                loader = prepare_dataset()
            else:
                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
        else:
            # 3. Build data-iterator
            loader = ASRTask.build_streaming_iterator(
                data_path_and_name_and_type,
                dtype=dtype,
                batch_size=batch_size,
                key_file=key_file,
                num_workers=num_workers,
                preprocess_fn=None,
                collate_fn=None,
                allow_variable_data_keys=allow_variable_data_keys,
                inference=True,
            )

        # 7. Start for-loop
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            output_writer = open("{}/result.txt".format(output_path), "w")
            pse_label_writer = open("{}/labels.txt".format(output_path), "w")
        logging.info("Start to diarize...")
        result_list = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # drop *_lengths entries and unwrap the singleton batch dimension
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            results, pse_labels = speech2diar(**batch)
            # Only supporting batch_size==1
            key, value = keys[0], output_results_str(results, keys[0])
            item = {"key": key, "value": value}
            result_list.append(item)
            if output_path is not None:
                # NOTE(review): result.txt entries carry no utterance key and no
                # separating newline between utterances — confirm intended format.
                output_writer.write(value)
                output_writer.flush()
                pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
                pse_label_writer.flush()

        if output_path is not None:
            output_writer.close()
            pse_label_writer.close()

        return result_list

    return _forward
+
+
def inference(
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    diar_train_config: Optional[str],
    diar_model_file: Optional[str],
    output_dir: Optional[str] = None,
    batch_size: int = 1,
    dtype: str = "float32",
    ngpu: int = 0,
    seed: int = 0,
    num_workers: int = 1,
    log_level: Union[int, str] = "INFO",
    key_file: Optional[str] = None,
    model_tag: Optional[str] = None,
    allow_variable_data_keys: bool = True,
    streaming: bool = False,
    smooth_size: int = 83,
    dur_threshold: int = 10,
    out_format: str = "vad",
    **kwargs,
):
    """One-shot diarization: build the pipeline and run it on the data list.

    Thin wrapper around inference_modelscope(); every option is forwarded
    unchanged and the returned forward function is invoked exactly once.
    """
    pipeline_options = dict(
        diar_train_config=diar_train_config,
        diar_model_file=diar_model_file,
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        model_tag=model_tag,
        allow_variable_data_keys=allow_variable_data_keys,
        streaming=streaming,
        smooth_size=smooth_size,
        dur_threshold=dur_threshold,
        out_format=out_format,
    )
    # Double unpacking keeps the original duplicate-keyword TypeError behavior.
    forward_fn = inference_modelscope(**pipeline_options, **kwargs)
    return forward_fn(data_path_and_name_and_type, raw_inputs=None)
+
+
def get_parser():
    """Build the CLI argument parser for SOND diarization inference.

    Returns:
        A config_argparse.ArgumentParser with decoding, input-data and
        model-configuration options.
    """
    parser = config_argparse.ArgumentParser(
        # BUGFIX: description was copy-pasted from the speaker-verification
        # script ("Speaker verification/x-vector extraction").
        description="Speaker diarization inference (SOND)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=False)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    parser.add_argument("--streaming", type=str2bool, default=False)

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--diar_train_config",
        type=str,
        help="diarization training configuration",
    )
    group.add_argument(
        "--diar_model_file",
        type=str,
        help="diarization model parameter file",
    )
    # These were previously registered on `parser`/mixed groups; names, types
    # and defaults are unchanged — only the help grouping is now consistent.
    group.add_argument(
        "--dur_threshold",
        type=int,
        default=10,
        help="The threshold for short segments in number frames"
    )
    group.add_argument(
        "--smooth_size",
        type=int,
        default=83,
        help="The smoothing window length in number frames"
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    return parser
+
+
def main(cmd=None):
    """CLI entry point for SOND diarization inference.

    Parses args, optionally pins the visible GPU (job id is the numeric
    suffix of --output_dir, e.g. "exp/log.3" -> job 3), runs inference and
    prints one "<key> <value>" line per utterance.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    logging.info("args: {}".format(kwargs))

    # BUGFIX: GPU pinning previously ran unconditionally, which (a) set
    # CUDA_VISIBLE_DEVICES="" on CPU runs with the default empty --gpuid_list
    # and (b) raised ValueError when --output_dir had no ".<jobid>" suffix.
    # Guard on ngpu, matching the other inference launchers in this package.
    if args.ngpu > 0:
        if args.output_dir is None:
            jobid, n_gpu = 1, 1
            gpuid = args.gpuid_list.split(",")[jobid - 1]
        else:
            jobid = int(args.output_dir.split(".")[-1])
            n_gpu = len(args.gpuid_list.split(","))
            gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    results_list = inference(**kwargs)
    for results in results_list:
        print("{} {}".format(results["key"], results["value"]))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index 57ce91d..a78bccd 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
import os
@@ -26,7 +29,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-
+from funasr.utils.misc import statistic_model_parameters
class Speech2Xvector:
"""Speech2Xvector class
@@ -59,6 +62,7 @@
device=device
)
logging.info("sv_model: {}".format(sv_model))
+ logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
logging.info("sv_train_args: {}".format(sv_train_args))
sv_model.to(dtype=getattr(torch, dtype)).eval()
@@ -156,17 +160,17 @@
def inference_modelscope(
- output_dir: Optional[str],
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- sv_train_config: Optional[str],
- sv_model_file: Optional[str],
- model_tag: Optional[str],
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pth",
+ model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
@@ -214,7 +218,6 @@
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
- fs: dict = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index c511dc7..1205d19 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
new file mode 100755
index 0000000..dc565d0
--- /dev/null
+++ b/funasr/bin/tokenize_text.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+import argparse
+from collections import Counter
+import logging
+from pathlib import Path
+import sys
+from typing import List
+from typing import Optional
+
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
def field2slice(field: Optional[str]) -> slice:
    """Convert field string to slice

    Note that field string accepts 1-based integer, following the
    convention of the Unix ``cut`` command.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
        >>> field2slice("2")
        slice(1, 2, None)

    Raises:
        RuntimeError: if the field specification is malformed or 0-based.
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2": select exactly one field.
            s1 = int(field)
            # BUGFIX: this was s1 + 1, which made a single-field spec like
            # "2" select two fields (slice(1, 3)) instead of just field 2.
            s2 = s1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        slic = slice(None, s2)
    else:
        # -1 because of 1-based integer following "cut" command
        # e.g "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic
+
+
def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
    """Tokenize a text stream, or build a vocabulary from it.

    In the default mode, each input line (optionally restricted to the
    1-based columns given by ``field``) is cleaned, tokenized and written
    out space-joined. With ``write_vocabulary=True``, token counts are
    accumulated instead and a vocabulary (one token per line, sorted by
    frequency, trimmed by ``cutoff``/``vocabulary_size`` and extended with
    ``add_symbol`` entries at their requested indices) is written.

    ``input``/``output`` may be "-" for stdin/stdout.
    """
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # NOTE(review): file handles are never closed explicitly; this relies on
    # process exit (the CLI use case) — confirm if used as a library call.
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )

    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-"
            # uttidA hello world!! -> hello world!!
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # Remove tokens that will be re-added via --add_symbol so they are not
    # counted twice in the vocabulary.
    for symbol_and_id in add_symbol:
        # e.g symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()
        if symbol in counter:
            del counter[symbol]

    # ======= write_vocabulary mode from here =======
    # Sort by the number of occurrences in descending order
    # and filter lower frequency words than cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )
    # Restrict the vocabulary size
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values of --add_symbol
    for symbol_and_id in add_symbol:
        # e.g symbol="<blank>:0"
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0 -> append as the first symbol
        # e.g. idx=-1 -> append as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write words
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Logging
    # Added symbols carry a count of None and are excluded from in-vocab mass.
    # NOTE(review): raises ZeroDivisionError when the input contained no
    # tokens at all — confirm empty input is out of scope for this CLI.
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
+
+
def get_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the text tokenizer CLI."""
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    # I/O: "-" selects stdin/stdout.
    parser.add_argument("--input", "-i", required=True, help="Input text. - indicates sys.stdin")
    parser.add_argument("--output", "-o", required=True, help="Output text. - indicates sys.stdout")
    parser.add_argument("--field", "-f", help="The target columns of the input text as 1-based integer. e.g 2-")

    # Tokenization options.
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument("--non_linguistic_symbols", type=str_or_none, help="non_linguistic_symbols file path")
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    vocab_group = parser.add_argument_group("write_vocabulary mode related")
    vocab_group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    vocab_group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    vocab_group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    vocab_group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )

    return parser
+
+
def main(cmd=None):
    """CLI entry point: echo the command line, parse args, run tokenize()."""
    print(get_commandline_args(), file=sys.stderr)
    args = get_parser().parse_args(cmd)
    tokenize(**vars(args))


if __name__ == "__main__":
    main()
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 10fbccb..79540c1 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -58,6 +58,15 @@
continue
return out_txt.strip().split()
def seg_tokenize_wo_pattern(txt, seg_dict):
    """Map each word through a segmentation dictionary (no pattern matching).

    Args:
        txt: iterable of words.
        seg_dict: mapping from word to its (space-joined) segmentation.

    Returns:
        Flat list of segment tokens; words missing from seg_dict contribute
        a single "<unk>" token.
    """
    # join/split instead of quadratic string concatenation in a loop
    pieces = [seg_dict[word] if word in seg_dict else "<unk>" for word in txt]
    return " ".join(pieces).split()
+
def framing(
x,
@@ -372,6 +381,70 @@
data = self._text_process(data)
return data
+## FIXME
class LMPreprocessor(CommonPreprocessor):
    """Text preprocessor for LM training.

    Behaves like CommonPreprocessor, except that when splitting on spaces
    with a segmentation dictionary it tokenizes via seg_tokenize_wo_pattern
    (plain dictionary lookup, no special pattern handling).
    """
    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        # NOTE(review): arguments are forwarded positionally, so this list must
        # match the exact parameter order of CommonPreprocessor.__init__ —
        # re-verify whenever the parent signature changes.
        super().__init__(train,
                         token_type,
                         token_list,
                         bpemodel,
                         text_cleaner,
                         g2p_type,
                         unk_symbol,
                         space_symbol,
                         non_linguistic_symbols,
                         delimiter,
                         rir_scp,
                         rir_apply_prob,
                         noise_scp,
                         noise_apply_prob,
                         noise_db_range,
                         speech_volume_normalize,
                         speech_name,
                         text_name,
                         split_with_space,
                         seg_dict_file,
                         )

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        # Tokenize data[text_name] in place and replace it with int64 token ids.
        # If text_name is absent or no tokenizer is configured, data is
        # returned unchanged.
        if self.text_name in data and self.tokenizer is not None:
            text = data[self.text_name]
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    # LM variant: plain dictionary lookup, no pattern handling
                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
+
class CommonPreprocessor_multi(AbsPreprocessor):
def __init__(
diff --git a/funasr/export/README.md b/funasr/export/README.md
index 9740f23..0ecf272 100644
--- a/funasr/export/README.md
+++ b/funasr/export/README.md
@@ -16,17 +16,13 @@
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
-from funasr.export.export_model import ASRModelExportParaformer
-
-output_dir = "../export" # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
## Export torchscripts format model
@@ -36,15 +32,12 @@
output_dir = "../export" # onnx/torchscripts model save path
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
-export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
Export model from local path
```python
-from funasr.export.export_model import ASRModelExportParaformer
-output_dir = "../export" # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
-export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
```
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 9a599eb..239fd6c 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,3 +1,4 @@
+import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
@@ -8,14 +9,15 @@
from funasr.bin.asr_inference_paraformer import Speech2Text
from funasr.export.models import get_model
-
-
+import numpy as np
+import random
class ASRModelExportParaformer:
def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
assert check_argument_types()
+ self.set_all_random_seed(0)
if cache_dir is None:
- cache_dir = Path.home() / "cache" / "export"
+ cache_dir = Path.home() / ".cache" / "export"
self.cache_dir = Path(cache_dir)
self.export_config = dict(
@@ -24,8 +26,9 @@
)
logging.info("output dir: {}".format(self.cache_dir))
self.onnx = onnx
+
- def export(
+ def _export(
self,
model: Speech2Text,
tag_name: str = None,
@@ -60,38 +63,38 @@
model_script = torch.jit.trace(model, dummy_input)
model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
- def export_from_modelscope(
- self,
- tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- ):
+ def set_all_random_seed(self, seed: int):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+ def export(self,
+ tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+ mode: str = 'paraformer',
+ ):
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- from modelscope.hub.snapshot_download import snapshot_download
-
- model_dir = snapshot_download(tag_name, cache_dir=self.cache_dir)
- asr_train_config = os.path.join(model_dir, 'config.yaml')
- asr_model_file = os.path.join(model_dir, 'model.pb')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
- model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, 'cpu'
- )
- self.export(model, tag_name)
-
- def export_from_local(
- self,
- tag_name: str = '/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
- ):
-
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-
model_dir = tag_name
+ if model_dir.startswith('damo/'):
+ from modelscope.hub.snapshot_download import snapshot_download
+ model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
asr_train_config = os.path.join(model_dir, 'config.yaml')
asr_model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
+ json_file = os.path.join(model_dir, 'configuration.json')
+ if mode is None:
+ import json
+ with open(json_file, 'r') as f:
+ config_data = json.load(f)
+ mode = config_data['model']['model_config']['mode']
+ if mode == 'paraformer':
+ from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+ elif mode == 'uniasr':
+ from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+
model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, 'cpu'
)
- self.export(model, tag_name)
+ self._export(model, tag_name)
+
def _export_onnx(self, model, verbose, path, enc_size=None):
if enc_size:
@@ -116,5 +119,5 @@
if __name__ == '__main__':
output_dir = "../export"
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
- export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
- # export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
\ No newline at end of file
+ export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+ # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
\ No newline at end of file
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index b21b080..ca2c813 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,42 +1,3 @@
-# from .ctc import CTC
-# from .joint_network import JointNetwork
-#
-# # encoder
-# from espnet2.asr.encoder.rnn_encoder import RNNEncoder as espnetRNNEncoder
-# from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder as espnetVGGRNNEncoder
-# from espnet2.asr.encoder.contextual_block_transformer_encoder import ContextualBlockTransformerEncoder as espnetContextualTransformer
-# from espnet2.asr.encoder.contextual_block_conformer_encoder import ContextualBlockConformerEncoder as espnetContextualConformer
-# from espnet2.asr.encoder.transformer_encoder import TransformerEncoder as espnetTransformerEncoder
-# from espnet2.asr.encoder.conformer_encoder import ConformerEncoder as espnetConformerEncoder
-# from funasr.export.models.encoder.rnn import RNNEncoder
-# from funasr.export.models.encoders import TransformerEncoder
-# from funasr.export.models.encoders import ConformerEncoder
-# from funasr.export.models.encoder.contextual_block_xformer import ContextualBlockXformerEncoder
-#
-# # decoder
-# from espnet2.asr.decoder.rnn_decoder import RNNDecoder as espnetRNNDecoder
-# from espnet2.asr.transducer.transducer_decoder import TransducerDecoder as espnetTransducerDecoder
-# from funasr.export.models.decoder.rnn import (
-# RNNDecoder
-# )
-# from funasr.export.models.decoders import XformerDecoder
-# from funasr.export.models.decoders import TransducerDecoder
-#
-# # lm
-# from espnet2.lm.seq_rnn_lm import SequentialRNNLM as espnetSequentialRNNLM
-# from espnet2.lm.transformer_lm import TransformerLM as espnetTransformerLM
-# from .language_models.seq_rnn import SequentialRNNLM
-# from .language_models.transformer import TransformerLM
-#
-# # frontend
-# from espnet2.asr.frontend.s3prl import S3prlFrontend as espnetS3PRLModel
-# from .frontends.s3prl import S3PRLModel
-#
-# from espnet2.asr.encoder.sanm_encoder import SANMEncoder_tf, SANMEncoderChunkOpt_tf
-# from espnet_onnx.export.asr.models.encoders.transformer_sanm import TransformerEncoderSANM_tf
-# from espnet2.asr.decoder.transformer_decoder import FsmnDecoderSCAMAOpt_tf
-# from funasr.export.models.decoders import XformerDecoderSANM
-
from funasr.models.e2e_asr_paraformer import Paraformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
@@ -45,47 +6,4 @@
if isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
else:
- raise "The model is not exist!"
-
-
-# def get_encoder(model, frontend, preencoder, predictor=None, export_config=None):
-# if isinstance(model, espnetRNNEncoder) or isinstance(model, espnetVGGRNNEncoder):
-# return RNNEncoder(model, frontend, preencoder, **export_config)
-# elif isinstance(model, espnetContextualTransformer) or isinstance(model, espnetContextualConformer):
-# return ContextualBlockXformerEncoder(model, **export_config)
-# elif isinstance(model, espnetTransformerEncoder):
-# return TransformerEncoder(model, frontend, preencoder, **export_config)
-# elif isinstance(model, espnetConformerEncoder):
-# return ConformerEncoder(model, frontend, preencoder, **export_config)
-# elif isinstance(model, SANMEncoder_tf) or isinstance(model, SANMEncoderChunkOpt_tf):
-# return TransformerEncoderSANM_tf(model, frontend, preencoder, predictor, **export_config)
-# else:
-# raise "The model is not exist!"
-
-
-#
-# def get_decoder(model, export_config):
-# if isinstance(model, espnetRNNDecoder):
-# return RNNDecoder(model, **export_config)
-# elif isinstance(model, espnetTransducerDecoder):
-# return TransducerDecoder(model, **export_config)
-# elif isinstance(model, FsmnDecoderSCAMAOpt_tf):
-# return XformerDecoderSANM(model, **export_config)
-# else:
-# return XformerDecoder(model, **export_config)
-#
-#
-# def get_lm(model, export_config):
-# if isinstance(model, espnetSequentialRNNLM):
-# return SequentialRNNLM(model, **export_config)
-# elif isinstance(model, espnetTransformerLM):
-# return TransformerLM(model, **export_config)
-#
-#
-# def get_frontend_models(model, export_config):
-# if isinstance(model, espnetS3PRLModel):
-# return S3PRLModel(model, **export_config)
-# else:
-# return None
-#
-
\ No newline at end of file
+ raise "The model is not exist!"
\ No newline at end of file
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index 4fc3b49..db11b67 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -46,10 +46,10 @@
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
- x = F.pad(text, [1, 0], "constant", self.eos)
+ x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
- t[i, l] = self.sos
+ t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
new file mode 100644
index 0000000..32f550a
--- /dev/null
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -0,0 +1,776 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.modules.streaming_utils import utils as myutils
+from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from typeguard import check_argument_types
+
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.modules.repeat import repeat
+from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+
+
class ContextualDecoderLayer(nn.Module):
    """Decoder layer (FFN + self-attention + cross-attention) that also
    returns its intermediate self-attention and source-attention outputs,
    which the contextual-bias branch consumes downstream.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer object."""
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # Secondary norms only exist when the matching attention module does.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if concat_after:
            # NOTE(review): these linears are never used in forward(); kept to
            # preserve the parameter set / state-dict layout.
            self.concat_linear1 = nn.Linear(2 * size, size)
            self.concat_linear2 = nn.Linear(2 * size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
        # Upstream modules may hand over a (tensor, aux) tuple; unwrap it.
        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        skip = tgt

        # Pre-norm FFN; note the FFN output feeds only the self-attention
        # input, while the residual connection uses the raw layer input.
        ffn_out = self.feed_forward(self.norm1(tgt) if self.normalize_before else tgt)
        attn_in = self.norm2(ffn_out) if self.normalize_before else ffn_out

        # The attention cache is an inference-only optimization.
        out, cache = self.self_attn(
            attn_in, tgt_mask, cache=None if self.training else cache
        )
        out = skip + self.dropout(out)
        x_self_attn = out

        skip = out
        src_in = self.norm3(out) if self.normalize_before else out
        x_src_attn = self.src_attn(src_in, memory, memory_mask)

        out = skip + self.dropout(x_src_attn)
        return out, tgt_mask, x_self_attn, x_src_attn
+
+
class ContexutalBiasDecoder(nn.Module):
    """Cross-attention-only decoder block: attends target hidden states over
    a bias/context memory and returns the attended (dropped-out) result.

    NOTE(review): the class name keeps the original (misspelled) spelling for
    checkpoint and caller compatibility.
    """

    def __init__(
        self,
        size,
        src_attn,
        dropout_rate,
        normalize_before=True,
    ):
        """Construct a ContexutalBiasDecoder object."""
        super(ContexutalBiasDecoder, self).__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        out = tgt
        if self.src_attn is None:
            # No attention module configured: pass the input through untouched.
            return out, tgt_mask, memory, memory_mask, cache
        if self.normalize_before:
            out = self.norm3(out)
        out = self.dropout(self.src_attn(out, memory, memory_mask))
        return out, tgt_mask, memory, memory_mask, cache
+
+
class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    SANM Paraformer decoder extended with a contextual-bias branch
    (``bias_decoder`` + ``bias_output``) for hotword biasing.
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
    ):
        """Build the contextual Paraformer decoder.

        Args:
            vocab_size: output vocabulary size.
            encoder_output_size: encoder hidden size (also decoder width).
            att_layer_num: number of self+cross attention layers; the last of
                them is replaced by the contextual ``last_decoder`` layer.
            num_blocks: total decoder blocks; blocks beyond ``att_layer_num``
                become self-attention-only ``decoders2`` layers.
            input_layer: ``'embed'``, ``'linear'`` or ``'none'`` (no embedding).
            sanm_shfit: FSMN memory shift; ``None`` derives it from kernel_size.
        """
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        if input_layer == 'none':
            # Decoder input is expected to already be a dense representation.
            self.embed = None
        elif input_layer == "embed":
            # BUG FIX: this branch was a separate `if`, so input_layer == 'none'
            # set self.embed = None and then fell into the final `else`,
            # raising ValueError and making 'none' unusable.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'none', 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        # First att_layer_num - 1 standard SANM layers; the last attention
        # layer is the contextual one below.
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.dropout = nn.Dropout(dropout_rate)
        # Contextual-bias branch: cross-attends decoder states over hotword
        # embeddings, then fuses via a 1x1 conv over concatenated channels.
        self.bias_decoder = ContexutalBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            dropout_rate=dropout_rate,
            normalize_before=True,
        )
        self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            # Self-attention-only layers (no cross attention, no FSMN shift).
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )

        # Final FFN-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
+
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            contextual_info: hotword/bias embeddings attended by the bias
                branch — presumably (batch, num_hotwords, feat); TODO confirm
                against caller.
            return_hidden: if True, skip the output projection and return the
                hidden states instead of vocabulary logits.
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # Boolean masks: (batch, maxlen_out, 1) for targets, (batch, 1, maxlen_in) for memory.
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        # Standard SANM layers, then the contextual last layer which also
        # exposes its self-attn / src-attn intermediate outputs.
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        # All hotword rows share one length (contextual_info.shape[1]), broadcast per batch item.
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)

        if self.bias_output is not None:
            # Fuse cross-attention and bias-attention streams with a 1x1 conv,
            # then add back onto the self-attention residual path.
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )

        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        # Valid output lengths recovered from the (batch, maxlen_out, 1) mask.
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens
+
    def gen_tf2torch_map_dict(self):
        """Build the torch-parameter -> TF-variable mapping used when porting
        a TensorFlow checkpoint into this decoder.

        Keys are torch parameter name templates ("layeridx" stands in for the
        layer index); values give the matching TF tensor name plus the
        ``squeeze`` / ``transpose`` operations needed to adapt the TF layout
        (e.g. conv1d kernels stored as (1, in, out)) to torch's.

        Returns:
            dict: torch-name template -> {"name", "squeeze", "transpose"} spec.
        """
        # NOTE(review): the two prefix attributes are not set in this file —
        # presumably assigned by the parent class / model wiring; confirm.
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {

            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 2, 0),
                 },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)

            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)

            # out layer
            # The output layer may share weights with the embedding table, so
            # two candidate TF names are listed; convert_tf2torch picks the
            # first that exists in the checkpoint.
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)

            ## clas decoder
            # The contextual-bias branch maps to TF layer 15 by (hard-coded)
            # convention — see the layeridx = 15 case in convert_tf2torch.
            # src att
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (1024,256),(1,256,1024)

        }
        return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+ map_dict = self.gen_tf2torch_map_dict()
+ var_dict_torch_update = dict()
+ decoder_layeridx_sets = set()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == self.tf2torch_tensor_name_prefix_torch:
+ if names[1] == "decoders":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[1] == "last_decoder":
+ layeridx = 15
+ name_q = name.replace("last_decoder", "decoders.layeridx")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+
+ elif names[1] == "decoders2":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ name_q = name_q.replace("decoders2", "decoders")
+ layeridx_bias = len(decoder_layeridx_sets)
+
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders3":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[1] == "bias_decoder":
+ name_q = name
+
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+
+ elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+ name_tf = map_dict[name]["name"]
+ if isinstance(name_tf, list):
+ idx_list = 0
+ if name_tf[idx_list] in var_dict_tf.keys():
+ pass
+ else:
+ idx_list = 1
+ data_tf = var_dict_tf[name_tf[idx_list]]
+ if map_dict[name]["squeeze"][idx_list] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+ if map_dict[name]["transpose"][idx_list] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+ name_tf[idx_list],
+ var_dict_tf[name_tf[
+ idx_list]].shape))
+
+ else:
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "after_norm":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed_concat_ffn":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 7596896..5786bc4 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -8,6 +8,8 @@
from typing import Union
import torch
+import random
+import numpy as np
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
@@ -24,7 +26,7 @@
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -824,7 +826,10 @@
class BiCifParaformer(Paraformer):
- """CTC-attention hybrid Encoder-Decoder model"""
+ """
+ Paraformer model with an extra cif predictor
+ to conduct accurate timestamp prediction
+ """
def __init__(
self,
@@ -891,7 +896,7 @@
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
- def _calc_att_loss(
+ def _calc_pre2_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
@@ -903,47 +908,12 @@
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
+ _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
- # 0. sampler
- decoder_out_1st = None
- if self.sampling_ratio > 0.0:
- if self.step_cur < 2:
- logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
- else:
- if self.step_cur < 2:
- logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds = pre_acoustic_embeds
+ # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+ loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
- # 1. Forward decoder
- decoder_outs = self.decoder(
- encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
- )
- decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
- if decoder_out_1st is None:
- decoder_out_1st = decoder_out
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_pad)
- acc_att = th_accuracy(
- decoder_out_1st.view(-1, self.vocab_size),
- ys_pad,
- ignore_label=self.ignore_id,
- )
- loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
- loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2)
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out_1st.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2
+ return loss_pre2
def calc_predictor(self, encoder_out, encoder_out_lens):
@@ -956,10 +926,154 @@
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, None, encoder_out_mask, token_num=token_num,
- ignore_id=self.ignore_id)
- import pdb; pdb.set_trace()
+ ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+ batch_size = speech.shape[0]
+ self.step_cur += 1
+ # for data-parallel
+ text = text[:, : text_lengths.max()]
+ speech = speech[:, :speech_lengths.max()]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ stats = dict()
+
+ loss_pre2 = self._calc_pre2_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ loss = loss_pre2
+
+ stats["loss_pre2"] = loss_pre2.detach().cpu()
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+class ContextualParaformer(Paraformer):
+ """
+ Paraformer model with contextual hotword
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ min_hw_length: int = 2,
+ max_hw_length: int = 4,
+ sample_rate: float = 0.6,
+ batch_rate: float = 0.5,
+ double_rate: float = -1.0,
+ target_buffer_length: int = -1,
+ inner_dim: int = 256,
+ bias_encoder_type: str = 'lstm',
+ label_bracket: bool = False,
+ ):
+ assert check_argument_types()
+ assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+ assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+ super().__init__(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
+ )
+
+ if bias_encoder_type == 'lstm':
+ logging.warning("enable bias encoder sampling and contextual training")
+ self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
+ self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
+ else:
+ logging.error("Unsupport bias encoder type")
+
+ self.min_hw_length = min_hw_length
+ self.max_hw_length = max_hw_length
+ self.sample_rate = sample_rate
+ self.batch_rate = batch_rate
+ self.target_buffer_length = target_buffer_length
+ self.double_rate = double_rate
+
+ if self.target_buffer_length > 0:
+ self.hotword_buffer = None
+ self.length_record = []
+ self.current_buffer_length = 0
def forward(
self,
@@ -1038,17 +1152,17 @@
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss(
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
- loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+ loss = loss_att + loss_pre * self.predictor_weight
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1056,10 +1170,292 @@
stats["cer"] = cer_att
stats["wer"] = wer_att
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
- stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None
stats["loss"] = torch.clone(loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
\ No newline at end of file
+ return loss, stats, weight
+
+ def _sample_hot_word(self, ys_pad, ys_pad_lens):
+ hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
+        hw_lengths = [0]  # this length is actually used as an index, hence the -1
+ for i, length in enumerate(ys_pad_lens):
+ if length < 2:
+ continue
+ if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
+ # sample double hotword
+ _max_hw_length = min(self.max_hw_length, length // 2)
+ # first hotword
+ start1 = random.randint(0, length // 3)
+ end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
+ hw_tokens1 = ys_pad[i][start1:end1 + 1]
+ hw_lengths.append(len(hw_tokens1) - 1)
+ hw_list.append(hw_tokens1)
+ # second hotword
+ start2 = random.randint(end1 + 1, length - self.min_hw_length)
+ end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
+ min(length - 1, start2 + self.max_hw_length - 1))
+ hw_tokens2 = ys_pad[i][start2:end2 + 1]
+ hw_lengths.append(len(hw_tokens2) - 1)
+ hw_list.append(hw_tokens2)
+ continue
+ if random.random() < self.sample_rate:
+ if length == 2:
+ hw_tokens = ys_pad[i][:2]
+ hw_lengths.append(1)
+ hw_list.append(hw_tokens)
+ else:
+ start = random.randint(0, length - self.min_hw_length)
+ end = random.randint(min(length - 1, start + self.min_hw_length - 1),
+ min(length - 1, start + self.max_hw_length - 1)) + 1
+ # print(start, end)
+ hw_tokens = ys_pad[i][start:end]
+ hw_lengths.append(len(hw_tokens) - 1)
+ hw_list.append(hw_tokens)
+ # padding
+ hw_list_pad = pad_list(hw_list, 0)
+ hw_embed = self.decoder.embed(hw_list_pad)
+ hw_embed, (_, _) = self.bias_encoder(hw_embed)
+ _ind = np.arange(0, len(hw_list)).tolist()
+        # update self.hotword_buffer, discarding a portion if it grows oversized
+ selected = hw_embed[_ind, hw_lengths]
+ if self.target_buffer_length > 0:
+ _b = selected.shape[0]
+ if self.hotword_buffer is None:
+ self.hotword_buffer = selected
+ self.length_record.append(selected.shape[0])
+ self.current_buffer_length = _b
+ elif self.current_buffer_length + _b < self.target_buffer_length:
+ self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
+ self.current_buffer_length += _b
+ selected = self.hotword_buffer
+ else:
+ self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
+ random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
+ self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
+ selected = self.hotword_buffer
+ self.current_buffer_length = selected.shape[0]
+ return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+ def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad)
+ with torch.no_grad():
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+ input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ if self.predictor_bias == 1:
+ _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_pad_lens = ys_pad_lens + self.predictor_bias
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
+
+ # sample hot word
+ contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)
+
+ # 0. sampler
+ decoder_out_1st = None
+ if self.sampling_ratio > 0.0:
+ if self.step_cur < 2:
+ logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds, contextual_info)
+ else:
+ if self.step_cur < 2:
+ logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+ sematic_embeds = pre_acoustic_embeds
+
+ # 1. Forward decoder
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+ if decoder_out_1st is None:
+ decoder_out_1st = decoder_out
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_pad)
+ acc_att = th_accuracy(
+ decoder_out_1st.view(-1, self.vocab_size),
+ ys_pad,
+ ignore_label=self.ignore_id,
+ )
+ loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out_1st.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+ if hw_list is None:
+ # default hotword list
+ hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list
+ hw_list_pad = pad_list(hw_list, 0)
+ hw_embed = self.bias_embed(hw_list_pad)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+ else:
+ hw_lengths = [len(i) for i in hw_list]
+ hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+ hw_embed = self.bias_embed(hw_list_pad)
+ hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+ enforce_sorted=False)
+ _, (h_n, _) = self.bias_encoder(hw_embed)
+ # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True)
+ contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
+ def gen_clas_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = "bias_encoder"
+ tensor_name_prefix_tf = "seq2seq/clas_charrnn"
+
+ tensor_name_prefix_torch_emb = "bias_embed"
+ tensor_name_prefix_tf_emb = "seq2seq"
+
+ map_dict_local = {
+ # in lstm
+ "{}.weight_ih_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (1, 0),
+ "slice": (0, 512),
+ "unit_k": 512,
+ }, # (1024, 2048),(2048,512)
+ "{}.weight_hh_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (1, 0),
+ "slice": (512, 1024),
+ "unit_k": 512,
+ }, # (1024, 2048),(2048,512)
+ "{}.bias_ih_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ "scale": 0.5,
+ "unit_b": 512,
+ }, # (2048,),(2048,)
+ "{}.bias_hh_l0".format(tensor_name_prefix_torch):
+ {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ "scale": 0.5,
+ "unit_b": 512,
+ }, # (2048,),(2048,)
+
+ # in embed
+ "{}.weight".format(tensor_name_prefix_torch_emb):
+ {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb),
+ "squeeze": None,
+ "transpose": None,
+ }, # (4235,256),(4235,256)
+ }
+ return map_dict_local
+
+ def clas_convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch):
+ map_dict = self.gen_clas_tf2torch_map_dict()
+ var_dict_torch_update = dict()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == "bias_encoder":
+ name_q = name
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q].get("unit_k") is not None:
+ dim = map_dict[name_q]["unit_k"]
+ i = data_tf[:, 0:dim].copy()
+ f = data_tf[:, dim:2 * dim].copy()
+ o = data_tf[:, 2 * dim:3 * dim].copy()
+ g = data_tf[:, 3 * dim:4 * dim].copy()
+ data_tf = np.concatenate([i, o, f, g], axis=1)
+ if map_dict[name_q].get("unit_b") is not None:
+ dim = map_dict[name_q]["unit_b"]
+ i = data_tf[0:dim].copy()
+ f = data_tf[dim:2 * dim].copy()
+ o = data_tf[2 * dim:3 * dim].copy()
+ g = data_tf[3 * dim:4 * dim].copy()
+ data_tf = np.concatenate([i, o, f, g], axis=0)
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q].get("slice") is not None:
+ data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
+ if map_dict[name_q].get("scale") is not None:
+ data_tf = data_tf * map_dict[name_q]["scale"]
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+ elif names[0] == "bias_embed":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
new file mode 100644
index 0000000..d29ffe5
--- /dev/null
+++ b/funasr/models/e2e_diar_sond.py
@@ -0,0 +1,402 @@
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from itertools import permutations
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import to_device
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class DiarSondModel(AbsESPnetModel):
+ """Speaker overlap-aware neural diarization model
+ reference: https://arxiv.org/abs/2211.10243
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ speaker_encoder: AbsEncoder,
+ ci_scorer: torch.nn.Module,
+ cd_scorer: torch.nn.Module,
+ decoder: torch.nn.Module,
+ token_list: list,
+ lsm_weight: float = 0.1,
+ length_normalized_loss: bool = False,
+ max_spk_num: int = 16,
+ label_aggregator: Optional[torch.nn.Module] = None,
+ normlize_speech_speaker: bool = False,
+ ):
+ assert check_argument_types()
+
+ super().__init__()
+
+ self.encoder = encoder
+ self.speaker_encoder = speaker_encoder
+ self.ci_scorer = ci_scorer
+ self.cd_scorer = cd_scorer
+ self.normalize = normalize
+ self.frontend = frontend
+ self.specaug = specaug
+ self.label_aggregator = label_aggregator
+ self.decoder = decoder
+ self.token_list = token_list
+ self.max_spk_num = max_spk_num
+ self.normalize_speech_speaker = normlize_speech_speaker
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor = None,
+ profile: torch.Tensor = None,
+ profile_lengths: torch.Tensor = None,
+ spk_labels: torch.Tensor = None,
+ spk_labels_lengths: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
+
+ Args:
+ speech: (Batch, samples)
+            speech_lengths: (Batch,) default None for chunk iterator,
+ because the chunk-iterator does not
+ have the speech_lengths returned.
+ see in
+ espnet2/iterators/chunk_iter_factory.py
+ profile: (Batch, N_spk, dim)
+ profile_lengths: (Batch,)
+ spk_labels: (Batch, )
+ """
+ assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
+ batch_size = speech.shape[0]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ if self.attractor is None:
+            # 2a. Decoder (basically a prediction layer after encoder_out)
+ pred = self.decoder(encoder_out, encoder_out_lens)
+ else:
+ # 2b. Encoder Decoder Attractors
+ # Shuffle the chronological order of encoder_out, then calculate attractor
+ encoder_out_shuffled = encoder_out.clone()
+ for i in range(len(encoder_out_lens)):
+ encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
+ i, torch.randperm(encoder_out_lens[i]), :
+ ]
+ attractor, att_prob = self.attractor(
+ encoder_out_shuffled,
+ encoder_out_lens,
+ to_device(
+ self,
+ torch.zeros(
+ encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
+ ),
+ ),
+ )
+ # Remove the final attractor which does not correspond to a speaker
+ # Then multiply the attractors and encoder_out
+ pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
+ # 3. Aggregate time-domain labels
+ if self.label_aggregator is not None:
+ spk_labels, spk_labels_lengths = self.label_aggregator(
+ spk_labels, spk_labels_lengths
+ )
+
+ # If encoder uses conv* as input_layer (i.e., subsampling),
+ # the sequence length of 'pred' might be slighly less than the
+ # length of 'spk_labels'. Here we force them to be equal.
+ length_diff_tolerance = 2
+ length_diff = spk_labels.shape[1] - pred.shape[1]
+ if length_diff > 0 and length_diff <= length_diff_tolerance:
+ spk_labels = spk_labels[:, 0 : pred.shape[1], :]
+
+ if self.attractor is None:
+ loss_pit, loss_att = None, None
+ loss, perm_idx, perm_list, label_perm = self.pit_loss(
+ pred, spk_labels, encoder_out_lens
+ )
+ else:
+ loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
+ pred, spk_labels, encoder_out_lens
+ )
+ loss_att = self.attractor_loss(att_prob, spk_labels)
+ loss = loss_pit + self.attractor_weight * loss_att
+ (
+ correct,
+ num_frames,
+ speech_scored,
+ speech_miss,
+ speech_falarm,
+ speaker_scored,
+ speaker_miss,
+ speaker_falarm,
+ speaker_error,
+ ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
+
+ if speech_scored > 0 and num_frames > 0:
+ sad_mr, sad_fr, mi, fa, cf, acc, der = (
+ speech_miss / speech_scored,
+ speech_falarm / speech_scored,
+ speaker_miss / speaker_scored,
+ speaker_falarm / speaker_scored,
+ speaker_error / speaker_scored,
+ correct / num_frames,
+ (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
+ )
+ else:
+ sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_att=loss_att.detach() if loss_att is not None else None,
+ loss_pit=loss_pit.detach() if loss_pit is not None else None,
+ sad_mr=sad_mr,
+ sad_fr=sad_fr,
+ mi=mi,
+ fa=fa,
+ cf=cf,
+ acc=acc,
+ der=der,
+ )
+
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ spk_labels: torch.Tensor = None,
+ spk_labels_lengths: torch.Tensor = None,
+ ) -> Dict[str, torch.Tensor]:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode_speaker(
+        self,
+        profile: torch.Tensor,
+        profile_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode speaker profiles.
+
+        Args:
+            profile: speaker embeddings, (Batch, N, D) — assumed N <= max_spk_num
+            profile_lengths: (Batch,)
+        Returns:
+            encoded profiles of shape (Batch, max_spk_num, D') and the lengths.
+        """
+        with autocast(False):
+            # Pad the speaker axis up to max_spk_num with all-zero profiles.
+            if profile.shape[1] < self.max_spk_num:
+                profile = F.pad(profile, [0, 0, 0, self.max_spk_num-profile.shape[1], 0, 0], "constant", 0.0)
+            # A profile slot is "real" iff its L2 norm is non-zero; the mask
+            # keeps padded slots at exactly zero after encoding.
+            profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()
+            # Unit-normalize each profile vector along the feature dim.
+            profile = F.normalize(profile, dim=2)
+            if self.speaker_encoder is not None:
+                profile = self.speaker_encoder(profile, profile_lengths)[0]
+                return profile * profile_mask, profile_lengths
+            else:
+                # No speaker encoder configured: return normalized profiles as-is.
+                return profile, profile_lengths
+
+    def encode_speech(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode speech with the frontend + encoder and zero out padded frames.
+
+        Args:
+            speech: (Batch, Length, ...) waveform or features
+            speech_lengths: (Batch,)
+        Returns:
+            encoder outputs with padding frames zeroed, and output lengths.
+        """
+        if self.encoder is not None:
+            speech, speech_lengths = self.encode(speech, speech_lengths)
+            # Mask padded frames so downstream similarity scores are not
+            # influenced by garbage past each utterance's true length.
+            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
+            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
+            return speech * speech_mask, speech_lengths
+        else:
+            # No encoder configured: pass the input through unchanged.
+            return speech, speech_lengths
+
+    @staticmethod
+    def concate_speech_ivc(
+        speech: torch.Tensor,
+        ivc: torch.Tensor
+    ) -> torch.Tensor:
+        """Pair every frame with every speaker profile by tiling and concatenating.
+
+        Args:
+            speech: frame features, (B, T, D)
+            ivc: speaker profiles (i-vectors), (B, N, D)
+        Returns:
+            concatenated pairs, (B, N, T, 2D)
+        """
+        nn, tt = ivc.shape[1], speech.shape[1]
+        speech = speech.unsqueeze(dim=1)  # B x 1 x T x D
+        speech = speech.expand(-1, nn, -1, -1)  # B x N x T x D
+        ivc = ivc.unsqueeze(dim=2)  # B x N x 1 x D
+        ivc = ivc.expand(-1, -1, tt, -1)  # B x N x T x D
+        sd_in = torch.cat([speech, ivc], dim=3)  # B x N x T x 2D
+        return sd_in
+
+    def calc_similarity(
+        self,
+        speech_encoder_outputs: torch.Tensor,
+        speaker_encoder_outputs: torch.Tensor,
+        seq_len: torch.Tensor = None,
+        spk_len: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """Score each (frame, speaker) pair with the CD and CI scorers.
+
+        Args:
+            speech_encoder_outputs: (B, T, D_sph)
+            speaker_encoder_outputs: (B, N, D_spk) with N == max_spk_num
+            seq_len: (B,) frame lengths, used for the pair-wise scorer inputs
+            spk_len: accepted but unused here
+        Returns:
+            similarity features, (B, T, ...) — CD and CI scores concatenated
+            along the last dim.
+        """
+        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
+        d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
+        if self.normalize_speech_speaker:
+            # Unit-normalize both sides so scores behave like cosine similarities.
+            speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
+            speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
+        # Build (frame, speaker) pair inputs and flatten speakers into the batch.
+        ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
+        ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
+        ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
+        ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
+        # Context-dependent score: one scalar per (frame, speaker).
+        cd_simi = self.cd_scorer(ge_in, ge_len)[0]
+        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
+        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])  # -> B x T x N
+
+        # Context-independent score: either another encoder over the pair
+        # inputs, or a direct dot/cosine scorer over the raw outputs.
+        if isinstance(self.ci_scorer, AbsEncoder):
+            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
+        else:
+            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
+        simi = torch.cat([cd_simi, ci_simi], dim=2)
+
+        return simi
+
+    def post_net_forward(self, simi, seq_len):
+        """Decode similarity features into per-frame logits via the post-net."""
+        logits = self.decoder(simi, seq_len)[0]
+
+        return logits
+
+    def prediction_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lengths: torch.Tensor,
+    ) -> torch.Tensor:
+        """Inference entry: speech + speaker profiles -> diarization logits.
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch,)
+            profile: speaker profiles, (Batch, N, D)
+            profile_lengths: (Batch,)
+        Returns:
+            per-frame, per-speaker logits from the post-net.
+        """
+        # speech encoding
+        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
+        # speaker encoding
+        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
+        # calculating similarity
+        similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
+        # post net forward
+        logits = self.post_net_forward(similarity, speech_lengths)
+
+        return logits
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch,)
+        Returns:
+            encoder_out: (Batch, Length2, Dim), encoder_out_lens: (Batch,)
+        """
+        # Feature extraction / augmentation / CMVN run in fp32 even under AMP.
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+
+        # Sanity checks: batch preserved, and no output frame index can be
+        # beyond the longest reported output length.
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply the frontend (if any) to raw speech.
+
+        If speech_lengths is None, every item is assumed to be full-length.
+        """
+        batch_size = speech.shape[0]
+        speech_lengths = (
+            speech_lengths
+            if speech_lengths is not None
+            else torch.ones(batch_size).int() * speech.shape[1]
+        )
+
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            # Frontend
+            # e.g. STFT and Feature extract
+            # data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    @staticmethod
+    def calc_diarization_error(pred, label, length):
+        """Accumulate SAD and diarization error counts over a batch.
+
+        Args:
+            pred: logits, (batch, max_len, num_output); > 0 counts as active
+            label: 0/1 reference labels, same shape
+            length: (batch,) valid frame counts
+        Returns:
+            tuple of raw counts: (correct, num_frames, speech_scored,
+            speech_miss, speech_falarm, speaker_scored, speaker_miss,
+            speaker_falarm, speaker_error) — rates are computed by the caller.
+        """
+        # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
+
+        (batch_size, max_len, num_output) = label.size()
+        # mask the padding part
+        mask = np.zeros((batch_size, max_len, num_output))
+        for i in range(batch_size):
+            mask[i, : length[i], :] = 1
+
+        # pred and label have the shape (batch_size, max_len, num_output)
+        label_np = label.data.cpu().numpy().astype(int)
+        pred_np = (pred.data.cpu().numpy() > 0).astype(int)
+        label_np = label_np * mask
+        pred_np = pred_np * mask
+        length = length.data.cpu().numpy()
+
+        # compute speech activity detection error
+        n_ref = np.sum(label_np, axis=2)
+        n_sys = np.sum(pred_np, axis=2)
+        speech_scored = float(np.sum(n_ref > 0))
+        speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
+        speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))
+
+        # compute speaker diarization error
+        speaker_scored = float(np.sum(n_ref))
+        speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
+        speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
+        n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
+        speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
+        correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
+        num_frames = np.sum(length)
+        return (
+            correct,
+            num_frames,
+            speech_scored,
+            speech_miss,
+            speech_falarm,
+            speaker_scored,
+            speaker_miss,
+            speaker_falarm,
+            speaker_error,
+        )
diff --git a/funasr/models/encoder/opennmt_encoders/__init__.py b/funasr/models/encoder/opennmt_encoders/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/encoder/opennmt_encoders/__init__.py
diff --git a/funasr/models/encoder/opennmt_encoders/ci_scorers.py b/funasr/models/encoder/opennmt_encoders/ci_scorers.py
new file mode 100644
index 0000000..50056ee
--- /dev/null
+++ b/funasr/models/encoder/opennmt_encoders/ci_scorers.py
@@ -0,0 +1,38 @@
+import torch
+from torch.nn import functional as F
+
+
+class DotScorer(torch.nn.Module):
+    """Context-independent scorer: dot product between frames and speaker embeddings."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        spk_emb: torch.Tensor,
+    ):
+        # xs_pad: B, T, D
+        # spk_emb: B, N, D
+        # scores: B, T, N via batched matmul with transposed embeddings
+        scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
+        return scores
+
+    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
+        # Parameter-free module: nothing to load from a TF checkpoint.
+        return {}
+
+
+class CosScorer(torch.nn.Module):
+    """Context-independent scorer: cosine similarity between frames and speaker embeddings."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        spk_emb: torch.Tensor,
+    ):
+        # xs_pad: B, T, D
+        # spk_emb: B, N, D
+        # broadcast to (B, T, N, D) pairs, then cosine over the feature dim -> B, T, N
+        scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
+        return scores
+
+    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
+        # Parameter-free module: nothing to load from a TF checkpoint.
+        return {}
diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
new file mode 100644
index 0000000..4096743
--- /dev/null
+++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -0,0 +1,277 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from typeguard import check_argument_types
+import numpy as np
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.layer_norm import LayerNorm
+from funasr.models.encoder.abs_encoder import AbsEncoder
+import math
+from funasr.modules.repeat import repeat
+
+
+class EncoderLayer(nn.Module):
+    """One Conv1d layer with length-preserving padding, optional BatchNorm and residual."""
+
+    def __init__(
+        self,
+        input_units,
+        num_units,
+        kernel_size=3,
+        activation="tanh",
+        stride=1,
+        include_batch_norm=False,
+        residual=False
+    ):
+        super().__init__()
+        # Asymmetric "same"-style padding: total pad = kernel_size - stride,
+        # split so the output length equals T / stride (when stride divides T).
+        left_padding = math.ceil((kernel_size - stride) / 2)
+        right_padding = kernel_size - stride - left_padding
+        self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.conv1d = nn.Conv1d(
+            input_units,
+            num_units,
+            kernel_size,
+            stride,
+        )
+        self.activation = self.get_activation(activation)
+        if include_batch_norm:
+            self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
+        self.residual = residual
+        self.include_batch_norm = include_batch_norm
+        self.input_units = input_units
+        self.num_units = num_units
+        self.stride = stride
+
+    @staticmethod
+    def get_activation(activation):
+        # Only "tanh" is special-cased; anything else falls back to ReLU.
+        if activation == "tanh":
+            return nn.Tanh()
+        else:
+            return nn.ReLU()
+
+    def forward(self, xs_pad, ilens=None):
+        # xs_pad is channel-first here: (B, C, T).
+        outputs = self.conv1d(self.conv_padding(xs_pad))
+        # Residual only when shapes are guaranteed to match.
+        if self.residual and self.stride == 1 and self.input_units == self.num_units:
+            outputs = outputs + xs_pad
+
+        if self.include_batch_norm:
+            outputs = self.bn(outputs)
+
+        # add parenthesis for repeat module
+        return self.activation(outputs), ilens
+
+
+class ConvEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+    Convolution encoder in OpenNMT framework
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        input_units,
+        num_units,
+        kernel_size=3,
+        dropout_rate=0.3,
+        position_encoder=None,
+        activation='tanh',
+        auxiliary_states=True,
+        out_units=None,
+        out_norm=False,
+        out_residual=False,
+        include_batchnorm=False,
+        regularization_weight=0.0,
+        stride=1,
+        tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
+        tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = num_units
+
+        self.num_layers = num_layers
+        self.input_units = input_units
+        self.num_units = num_units
+        self.kernel_size = kernel_size
+        self.dropout_rate = dropout_rate
+        self.position_encoder = position_encoder
+        self.out_units = out_units
+        self.auxiliary_states = auxiliary_states
+        self.out_norm = out_norm
+        self.activation = activation
+        self.out_residual = out_residual
+        self.include_batch_norm = include_batchnorm
+        self.regularization_weight = regularization_weight
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        # A scalar stride is broadcast to every layer; a list is used as-is.
+        if isinstance(stride, int):
+            self.stride = [stride] * self.num_layers
+        else:
+            self.stride = stride
+        # Overall time-downsampling factor = product of per-layer strides.
+        self.downsample_rate = 1
+        for s in self.stride:
+            self.downsample_rate *= s
+
+        self.dropout = nn.Dropout(dropout_rate)
+        # First layer maps input_units -> num_units; later layers keep num_units
+        # and enable the residual connection.
+        self.cnn_a = repeat(
+            self.num_layers,
+            lambda lnum: EncoderLayer(
+                input_units if lnum == 0 else num_units,
+                num_units,
+                kernel_size,
+                activation,
+                self.stride[lnum],
+                include_batchnorm,
+                residual=True if lnum > 0 else False
+            )
+        )
+
+        if self.out_units is not None:
+            # NOTE(review): 'stride' here is the raw ctor argument, which may be a
+            # list; 'kernel_size - stride' would then raise. Presumably out_units
+            # is only used together with an int stride — confirm with callers.
+            left_padding = math.ceil((kernel_size - stride) / 2)
+            right_padding = kernel_size - stride - left_padding
+            self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+            self.conv_out = nn.Conv1d(
+                num_units,
+                num_units,
+                kernel_size,
+            )
+
+        if self.out_norm:
+            self.after_norm = LayerNorm(num_units)
+
+    def output_size(self) -> int:
+        # Output feature dim of the encoder (equals num_units).
+        return self.num_units
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Forward pass: (B, T, D) in -> (B, T', num_units) out, plus ilens and None state.
+
+        NOTE(review): ilens is returned unchanged even when strides > 1 — confirm
+        callers account for the downsample_rate themselves.
+        """
+        inputs = xs_pad
+        if self.position_encoder is not None:
+            inputs = self.position_encoder(inputs)
+
+        if self.dropout_rate > 0:
+            inputs = self.dropout(inputs)
+
+        # Conv1d stack works channel-first: transpose to (B, D, T).
+        outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
+
+        if self.out_units is not None:
+            outputs = self.conv_out(self.out_padding(outputs))
+
+        # Back to (B, T, D).
+        outputs = outputs.transpose(1, 2)
+        if self.out_norm:
+            outputs = self.after_norm(outputs)
+
+        if self.out_residual:
+            outputs = outputs + inputs
+
+        return outputs, ilens, None
+
+    def gen_tf2torch_map_dict(self):
+        """Build the torch-name -> TF-name mapping for convert_tf2torch.
+
+        'layeridx' in the keys/values is a placeholder substituted per layer.
+        """
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        map_dict_local = {
+            # torch: conv1d.weight in "out_channel in_channel kernel_size"
+            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
+            # torch: linear.weight in "out_channel in_channel"
+            # tf   :  dense.weight in "in_channel out_channel"
+            "{}.cnn_a.0.conv1d.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cnn_a/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+            "{}.cnn_a.0.conv1d.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cnn_a/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+
+            "{}.cnn_a.layeridx.conv1d.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cnn_a/conv1d_layeridx/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+            "{}.cnn_a.layeridx.conv1d.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cnn_a/conv1d_layeridx/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+        }
+        if self.out_units is not None:
+            # add output layer
+            map_dict_local.update({
+                "{}.conv_out.weight".format(tensor_name_prefix_torch):
+                    {"name": "{}/cnn_a/conv1d_{}/kernel".format(tensor_name_prefix_tf, self.num_layers),
+                     "squeeze": None,
+                     "transpose": (2, 1, 0),
+                     },  # tf: (1, 256, 256) -> torch: (256, 256, 1)
+                "{}.conv_out.bias".format(tensor_name_prefix_torch):
+                    {"name": "{}/cnn_a/conv1d_{}/bias".format(tensor_name_prefix_tf, self.num_layers),
+                     "squeeze": None,
+                     "transpose": None,
+                     },  # tf: (256,) -> torch: (256,)
+            })
+
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+        """Map TF checkpoint tensors onto this module's parameter names."""
+        map_dict = self.gen_tf2torch_map_dict()
+
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
+                # process special (first and last) layers
+                if name in map_dict:
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    if map_dict[name]["squeeze"] is not None:
+                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                    if map_dict[name]["transpose"] is not None:
+                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    assert var_dict_torch[name].size() == data_tf.size(), \
+                        "{}, {}, {} != {}".format(name, name_tf,
+                                                  var_dict_torch[name].size(), data_tf.size())
+                    var_dict_torch_update[name] = data_tf
+                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                    ))
+                # process general layers
+                else:
+                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
+                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), \
+                            "{}, {}, {} != {}".format(name, name_tf,
+                                                      var_dict_torch[name].size(), data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                        ))
+                    else:
+                        logging.warning("{} is missed from tf checkpoint".format(name))
+
+        return var_dict_torch_update
+
diff --git a/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
new file mode 100644
index 0000000..e41b2aa
--- /dev/null
+++ b/funasr/models/encoder/opennmt_encoders/fsmn_encoder.py
@@ -0,0 +1,335 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from typeguard import check_argument_types
+import numpy as np
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.layer_norm import LayerNorm
+from funasr.models.encoder.abs_encoder import AbsEncoder
+import math
+from funasr.modules.repeat import repeat
+from funasr.modules.multi_layer_conv import FsmnFeedForward
+
+
+class FsmnBlock(torch.nn.Module):
+    """FSMN memory block: depthwise Conv1d over time with a residual add.
+
+    fsmn_shift > 0 shifts the padding left, trading look-ahead for latency.
+    """
+
+    def __init__(
+        self,
+        n_feat,
+        dropout_rate,
+        kernel_size,
+        fsmn_shift=0,
+    ):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+        # Depthwise (groups=n_feat) conv: each feature dim gets its own filter.
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
+                                    padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if fsmn_shift > 0:
+            left_padding = left_padding + fsmn_shift
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+
+    def forward(self, inputs, mask, mask_shfit_chunk=None):
+        # inputs: (B, T, D); mask: padding mask broadcastable to (B, T, 1).
+        # NOTE(review): 'inputs * mask' below assumes mask is not None — confirm
+        # callers always pass a mask.
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+
+        inputs = inputs * mask
+        # Conv1d over time needs channel-first layout.
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        # Residual memory and re-masking of padded frames.
+        x = x + inputs
+        x = self.dropout(x)
+        return x * mask
+
+
+class EncoderLayer(torch.nn.Module):
+    """FSMN encoder layer: feed-forward followed by the FSMN memory block.
+
+    A residual connection is applied only when in_size == size.
+    """
+
+    def __init__(
+        self,
+        in_size,
+        size,
+        feed_forward,
+        fsmn_block,
+        dropout_rate=0.0
+    ):
+        super().__init__()
+        self.in_size = in_size
+        self.size = size
+        self.ffn = feed_forward
+        self.memory = fsmn_block
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        mask: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # xs_pad in Batch, Time, Dim
+
+        context = self.ffn(xs_pad)[0]
+        memory = self.memory(context, mask)
+
+        memory = self.dropout(memory)
+        if self.in_size == self.size:
+            return memory + xs_pad, mask
+
+        return memory, mask
+
+
+class FsmnEncoder(AbsEncoder):
+    """Encoder using Fsmn
+    """
+
+    def __init__(self,
+                 in_units,
+                 filter_size,
+                 fsmn_num_layers,
+                 dnn_num_layers,
+                 num_memory_units=512,
+                 ffn_inner_dim=2048,
+                 dropout_rate=0.0,
+                 shift=0,
+                 position_encoder=None,
+                 sample_rate=1,
+                 out_units=None,
+                 tf2torch_tensor_name_prefix_torch="post_net",
+                 tf2torch_tensor_name_prefix_tf="EAND/post_net"
+                 ):
+        """Initializes the parameters of the encoder.
+
+        Args:
+          filter_size: the total order of memory block
+          fsmn_num_layers: The number of fsmn layers.
+          dnn_num_layers: The number of dnn layers
+          num_units: The number of memory units.
+          ffn_inner_dim: The number of units of the inner linear transformation
+            in the feed forward layer.
+          dropout_rate: The probability to drop units from the outputs.
+          shift: left padding, to control delay
+          position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
+            apply on inputs or ``None``.
+        """
+        super(FsmnEncoder, self).__init__()
+        self.in_units = in_units
+        self.filter_size = filter_size
+        self.fsmn_num_layers = fsmn_num_layers
+        self.dnn_num_layers = dnn_num_layers
+        self.num_memory_units = num_memory_units
+        self.ffn_inner_dim = ffn_inner_dim
+        self.dropout_rate = dropout_rate
+        # A scalar shift is broadcast to one value per FSMN layer.
+        self.shift = shift
+        if not isinstance(shift, list):
+            self.shift = [shift for _ in range(self.fsmn_num_layers)]
+        # NOTE(review): sample_rate is stored per layer but not used in forward
+        # below — confirm whether it is consumed elsewhere.
+        self.sample_rate = sample_rate
+        if not isinstance(sample_rate, list):
+            self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
+        self.position_encoder = position_encoder
+        self.dropout = nn.Dropout(dropout_rate)
+        self.out_units = out_units
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+        # Stack of FSMN layers; only the first maps in_units -> num_memory_units.
+        self.fsmn_layers = repeat(
+            self.fsmn_num_layers,
+            lambda lnum: EncoderLayer(
+                in_units if lnum == 0 else num_memory_units,
+                num_memory_units,
+                FsmnFeedForward(
+                    in_units if lnum == 0 else num_memory_units,
+                    ffn_inner_dim,
+                    num_memory_units,
+                    1,
+                    dropout_rate
+                ),
+                FsmnBlock(
+                    num_memory_units,
+                    dropout_rate,
+                    filter_size,
+                    self.shift[lnum]
+                )
+            ),
+        )
+
+        # Plain feed-forward layers applied after the FSMN stack.
+        self.dnn_layers = repeat(
+            dnn_num_layers,
+            lambda lnum: FsmnFeedForward(
+                num_memory_units,
+                ffn_inner_dim,
+                num_memory_units,
+                1,
+                dropout_rate,
+            )
+        )
+        if out_units is not None:
+            # 1x1 conv acts as a linear projection to the output dimension.
+            self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
+
+    def output_size(self) -> int:
+        # NOTE(review): returns num_memory_units even when out_units projects to
+        # a different dim — confirm callers expect the pre-projection size.
+        return self.num_memory_units
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Forward pass: (B, T, D) in -> (B, T, D') out, lengths unchanged."""
+        inputs = xs_pad
+        if self.position_encoder is not None:
+            inputs = self.position_encoder(inputs)
+
+        inputs = self.dropout(inputs)
+        # Padding mask shaped (B, 1, T) for the FSMN blocks.
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        inputs = self.fsmn_layers(inputs, masks)[0]
+        inputs = self.dnn_layers(inputs)[0]
+
+        if self.out_units is not None:
+            inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)
+
+        return inputs, ilens, None
+
+    def gen_tf2torch_map_dict(self):
+        """Build the torch-name -> TF-name mapping for convert_tf2torch.
+
+        'layeridx' in the keys/values is a placeholder substituted per layer.
+        """
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        map_dict_local = {
+            # torch: conv1d.weight in "out_channel in_channel kernel_size"
+            # tf   : conv1d.weight in "kernel_size in_channel out_channel"
+            # torch: linear.weight in "out_channel in_channel"
+            # tf   :  dense.weight in "in_channel out_channel"
+            # for fsmn_layers
+            "{}.fsmn_layers.layeridx.ffn.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.fsmn_layers.layeridx.ffn.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.fsmn_layers.layeridx.ffn.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.fsmn_layers.layeridx.ffn.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+            "{}.fsmn_layers.layeridx.ffn.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+            "{}.fsmn_layers.layeridx.memory.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/fsmn_layer_layeridx/memory/depth_conv_w".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 2, 0),
+                 },  # (1, 31, 512, 1) -> (31, 512, 1) -> (512, 1, 31)
+
+            # for dnn_layers
+            "{}.dnn_layers.layeridx.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.dnn_layers.layeridx.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.dnn_layers.layeridx.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },
+            "{}.dnn_layers.layeridx.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+            "{}.dnn_layers.layeridx.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (2, 1, 0),
+                 },
+
+        }
+        if self.out_units is not None:
+            # add output layer
+            map_dict_local.update({
+                "{}.conv1d.weight".format(tensor_name_prefix_torch):
+                    {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
+                     "squeeze": None,
+                     "transpose": (2, 1, 0),
+                     },
+                "{}.conv1d.bias".format(tensor_name_prefix_torch):
+                    {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
+                     "squeeze": None,
+                     "transpose": None,
+                     },
+            })
+
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+        """Map TF checkpoint tensors onto this module's parameter names."""
+        map_dict = self.gen_tf2torch_map_dict()
+
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
+                # process special (first and last) layers
+                if name in map_dict:
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    if map_dict[name]["squeeze"] is not None:
+                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                    if map_dict[name]["transpose"] is not None:
+                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    assert var_dict_torch[name].size() == data_tf.size(), \
+                        "{}, {}, {} != {}".format(name, name_tf,
+                                                  var_dict_torch[name].size(), data_tf.size())
+                    var_dict_torch_update[name] = data_tf
+                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                    ))
+                # process general layers
+                else:
+                    # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
+                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), \
+                            "{}, {}, {} != {}".format(name, name_tf,
+                                                      var_dict_torch[name].size(), data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                        ))
+                    else:
+                        logging.warning("{} is missed from tf checkpoint".format(name))
+
+        return var_dict_torch_update
diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
new file mode 100644
index 0000000..443b37a
--- /dev/null
+++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -0,0 +1,480 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
+from typeguard import check_argument_types
+import numpy as np
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
+from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.multi_layer_conv import Conv1dLinear
+from funasr.modules.multi_layer_conv import MultiLayeredConv1d
+from funasr.modules.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import Conv2dSubsampling
+from funasr.modules.subsampling import Conv2dSubsampling2
+from funasr.modules.subsampling import Conv2dSubsampling6
+from funasr.modules.subsampling import Conv2dSubsampling8
+from funasr.modules.subsampling import TooShortUttError
+from funasr.modules.subsampling import check_short_utt
+from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
+
+
+class EncoderLayer(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(in_size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.in_size = in_size
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
+ """Compute encoded features.
+
+ Args:
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ if self.in_size == self.size:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ else:
+ x = stoch_layer_coeff * self.dropout(
+ self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, mask, cache, mask_att_chunk_encoder
+
+
+class SelfAttentionEncoder(AbsEncoder):
+ """
+ author: Speech Lab, Alibaba Group, China
+ Self attention encoder in OpenNMT framework
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ out_units=None,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self._output_size = output_size
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ SinusoidalPositionEncoder(),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ elif input_layer == "pe":
+ self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "null":
+ self.embed = None
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+ self.normalize_before = normalize_before
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d":
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ elif positionwise_layer_type == "conv1d-linear":
+ positionwise_layer = Conv1dLinear
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ positionwise_conv_kernel_size,
+ dropout_rate,
+ )
+ else:
+ raise NotImplementedError("Support only linear or conv1d.")
+
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EncoderLayer(
+ output_size,
+ output_size,
+ MultiHeadSelfAttention(
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ) if lnum > 0 else EncoderLayer(
+ input_size,
+ output_size,
+ MultiHeadSelfAttention(
+ attention_heads,
+ input_size if input_layer == "pe" or input_layer == "null" else output_size,
+ output_size,
+ attention_dropout_rate,
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+ self.dropout = nn.Dropout(dropout_rate)
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.out_units = out_units
+ if out_units is not None:
+ self.output_linear = nn.Linear(output_size, out_units)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad = xs_pad * self.output_size() ** 0.5  # out-of-place: do not mutate the caller's tensor in place
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        xs_pad = self.dropout(xs_pad)
+        # encoder_outs = self.encoders0(xs_pad, masks)
+        # xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        if self.out_units is not None:
+            xs_pad = self.output_linear(xs_pad)
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
+ def gen_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ map_dict_local = {
+ # cicd
+ # torch: conv1d.weight in "out_channel in_channel kernel_size"
+ # tf : conv1d.weight in "kernel_size in_channel out_channel"
+ # torch: linear.weight in "out_channel in_channel"
+ # tf : dense.weight in "in_channel out_channel"
+ "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (768,256),(1,256,768)
+ "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (768,),(768,)
+ "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # ffn
+ "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+ "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # out norm
+ "{}.after_norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.after_norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ }
+ if self.out_units is not None:
+ map_dict_local.update({
+ "{}.output_linear.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ },
+ "{}.output_linear.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ })
+
+ return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+
+        map_dict = self.gen_tf2torch_map_dict()
+
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            if name.startswith(self.tf2torch_tensor_name_prefix_torch):
+                # process special (first and last) layers
+                if name in map_dict:
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    if map_dict[name]["squeeze"] is not None:
+                        data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                    if map_dict[name]["transpose"] is not None:
+                        data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")  # convert after numpy ops, matching the general branch below
+                    assert var_dict_torch[name].size() == data_tf.size(), \
+                        "{}, {}, {} != {}".format(name, name_tf,
+                                                  var_dict_torch[name].size(), data_tf.size())
+                    var_dict_torch_update[name] = data_tf
+                    logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                        name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                    ))
+                # process general layers
+                else:
+                    # self.tf2torch_tensor_name_prefix_torch may include ".", handle that case
+                    names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), \
+                            "{}, {}, {} != {}".format(name, name_tf,
+                                                      var_dict_torch[name].size(), data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+                        ))
+                    else:
+                        logging.warning("{} is missed from tf checkpoint".format(name))
+
+        return var_dict_torch_update
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 66e446c..952ce15 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -1,7 +1,11 @@
import torch
from torch.nn import functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
-from typing import Tuple
+from typing import Tuple, Optional
+from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
+from collections import OrderedDict
+import logging
+import numpy as np
class BasicLayer(torch.nn.Module):
@@ -116,10 +120,18 @@
self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm2d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
+ self.time_ds_ratio = 8
+
def output_size(self) -> int:
return self.num_nodes_pooling_layer
- def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
features = xs_pad
assert features.size(-1) == self.input_size, \
"Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
@@ -141,4 +153,463 @@
features = F.relu(features)
features = self.resnet0_bn(features)
- return features, ilens // 8
+ return features, resnet_out_lens
+
+# Note: For training, this implementation is not equivalent to tf because of the kernel_regularizer in tf.layers.
+# TODO: implement kernel_regularizer in torch with manual loss addition or weight_decay in the optimizer
+class ResNet34_SP_L2Reg(AbsEncoder):
+ def __init__(
+ self,
+ input_size,
+ use_head_conv=True,
+ batchnorm_momentum=0.5,
+ use_head_maxpool=False,
+ num_nodes_pooling_layer=256,
+ layers_in_block=(3, 4, 6, 3),
+ filters_in_block=(32, 64, 128, 256),
+ tf2torch_tensor_name_prefix_torch="encoder",
+ tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
+ tf_train_steps=720000,
+ ):
+ super(ResNet34_SP_L2Reg, self).__init__()
+
+ self.use_head_conv = use_head_conv
+ self.use_head_maxpool = use_head_maxpool
+ self.num_nodes_pooling_layer = num_nodes_pooling_layer
+ self.layers_in_block = layers_in_block
+ self.filters_in_block = filters_in_block
+ self.input_size = input_size
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.tf_train_steps = tf_train_steps
+
+ pre_filters = filters_in_block[0]
+ if use_head_conv:
+ self.pre_conv = torch.nn.Conv2d(1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros")
+ self.pre_conv_bn = torch.nn.BatchNorm2d(pre_filters, eps=1e-3, momentum=batchnorm_momentum)
+
+ if use_head_maxpool:
+ self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
+
+ for i in range(len(layers_in_block)):
+ if i == 0:
+ in_filters = pre_filters if self.use_head_conv else 1
+ else:
+ in_filters = filters_in_block[i-1]
+
+ block = BasicBlock(in_filters,
+ filters=filters_in_block[i],
+ num_layer=layers_in_block[i],
+ stride=1 if i == 0 else 2,
+ bn_momentum=batchnorm_momentum)
+ self.add_module("block_{}".format(i), block)
+
+ self.resnet0_dense = torch.nn.Conv1d(filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1)
+ self.resnet0_bn = torch.nn.BatchNorm1d(num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum)
+
+ self.time_ds_ratio = 8
+
+ def output_size(self) -> int:
+ return self.num_nodes_pooling_layer
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ features = xs_pad
+ assert features.size(-1) == self.input_size, \
+ "Dimension of features {} doesn't match the input_size {}.".format(features.size(-1), self.input_size)
+ features = torch.unsqueeze(features, dim=1)
+ if self.use_head_conv:
+ features = self.pre_conv(features)
+ features = self.pre_conv_bn(features)
+ features = F.relu(features)
+
+ if self.use_head_maxpool:
+ features = self.head_maxpool(features)
+
+ resnet_outs, resnet_out_lens = features, ilens
+ for i in range(len(self.layers_in_block)):
+ block = self._modules["block_{}".format(i)]
+ resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
+
+ # B, C, T, F
+ bb, cc, tt, ff = resnet_outs.shape
+ resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff*cc, tt])
+ features = self.resnet0_dense(resnet_outs)
+ features = F.relu(features)
+ features = self.resnet0_bn(features)
+
+ return features, resnet_out_lens
+
+ def gen_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ train_steps = self.tf_train_steps
+ map_dict_local = {
+ # torch: conv1d.weight in "out_channel in_channel kernel_size"
+ # tf : conv1d.weight in "kernel_size in_channel out_channel"
+ # torch: linear.weight in "out_channel in_channel"
+ # tf : dense.weight in "in_channel out_channel"
+ "{}.pre_conv.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (3, 2, 0, 1),
+ },
+ "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
+ }
+ for layer_idx in range(3):
+ map_dict_local.update({
+ "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": (2, 1, 0) if layer_idx == 0 else (1, 0),
+ },
+ "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
+ })
+
+ for block_idx in range(len(self.layers_in_block)):
+ for layer_idx in range(self.layers_in_block[block_idx]):
+ for i in ["1", "2", "_sc"]:
+ map_dict_local.update({
+ "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": (3, 2, 0, 1),
+ },
+ "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
+ })
+
+ return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+
+ map_dict = self.gen_tf2torch_map_dict()
+
+ var_dict_torch_update = dict()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ if name.startswith(self.tf2torch_tensor_name_prefix_torch):
+ if name in map_dict:
+ if "num_batches_tracked" not in name:
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), \
+ "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[name].size(), data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+ name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+ ))
+ else:
+ var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
+ logging.info("torch tensor: {}, manually assigning to: {}".format(
+ name, map_dict[name]
+ ))
+ else:
+ logging.warning("{} is missed from tf checkpoint".format(name))
+
+ return var_dict_torch_update
+
+
+
+class ResNet34Diar(ResNet34):
+ def __init__(
+ self,
+ input_size,
+ embedding_node="resnet1_dense",
+ use_head_conv=True,
+ batchnorm_momentum=0.5,
+ use_head_maxpool=False,
+ num_nodes_pooling_layer=256,
+ layers_in_block=(3, 4, 6, 3),
+ filters_in_block=(32, 64, 128, 256),
+ num_nodes_resnet1=256,
+ num_nodes_last_layer=256,
+ pooling_type="window_shift",
+ pool_size=20,
+ stride=1,
+ tf2torch_tensor_name_prefix_torch="encoder",
+ tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
+ ):
+ super(ResNet34Diar, self).__init__(
+ input_size,
+ use_head_conv=use_head_conv,
+ batchnorm_momentum=batchnorm_momentum,
+ use_head_maxpool=use_head_maxpool,
+ num_nodes_pooling_layer=num_nodes_pooling_layer,
+ layers_in_block=layers_in_block,
+ filters_in_block=filters_in_block,
+ )
+
+ self.embedding_node = embedding_node
+ self.num_nodes_resnet1 = num_nodes_resnet1
+ self.num_nodes_last_layer = num_nodes_last_layer
+ self.pooling_type = pooling_type
+ self.pool_size = pool_size
+ self.stride = stride
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+ self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
+ self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
+
+ self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
+ self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
+
+ def output_size(self) -> int:
+ if self.embedding_node.startswith("resnet1"):
+ return self.num_nodes_resnet1
+ elif self.embedding_node.startswith("resnet2"):
+ return self.num_nodes_last_layer
+
+ return self.num_nodes_pooling_layer
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+
+ endpoints = OrderedDict()
+ res_out, ilens = super().forward(xs_pad, ilens)
+ endpoints["resnet0_bn"] = res_out
+ if self.pooling_type == "frame_gsp":
+ features = statistic_pooling(res_out, ilens, (3, ))
+ else:
+ features, ilens = windowed_statistic_pooling(res_out, ilens, (2, 3), self.pool_size, self.stride)
+ features = features.transpose(1, 2)
+ endpoints["pooling"] = features
+
+ features = self.resnet1_dense(features)
+ endpoints["resnet1_dense"] = features
+ features = F.relu(features)
+ endpoints["resnet1_relu"] = features
+ features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
+ endpoints["resnet1_bn"] = features
+
+ features = self.resnet2_dense(features)
+ endpoints["resnet2_dense"] = features
+ features = F.relu(features)
+ endpoints["resnet2_relu"] = features
+ features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
+ endpoints["resnet2_bn"] = features
+
+ return endpoints[self.embedding_node], ilens, None
+
+ def gen_tf2torch_map_dict(self):
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ train_steps = 300000
+ map_dict_local = {
+ # torch: conv1d.weight in "out_channel in_channel kernel_size"
+ # tf : conv1d.weight in "kernel_size in_channel out_channel"
+ # torch: linear.weight in "out_channel in_channel"
+ # tf : dense.weight in "in_channel out_channel"
+ "{}.pre_conv.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv/kernel".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": (3, 2, 0, 1),
+ },
+ "{}.pre_conv_bn.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.running_mean".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/moving_mean".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.running_var".format(tensor_name_prefix_torch):
+ {"name": "{}/pre_conv_bn/moving_variance".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.pre_conv_bn.num_batches_tracked".format(tensor_name_prefix_torch): train_steps
+ }
+ for layer_idx in range(3):
+ map_dict_local.update({
+ "{}.resnet{}_dense.weight".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_dense/kernel".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": (3, 2, 0, 1) if layer_idx == 0 else (1, 0),
+ },
+ "{}.resnet{}_dense.bias".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_dense/bias".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.weight".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/gamma".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.bias".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/beta".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.running_mean".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/moving_mean".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.running_var".format(tensor_name_prefix_torch, layer_idx):
+ {"name": "{}/resnet{}_bn/moving_variance".format(tensor_name_prefix_tf, layer_idx),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.resnet{}_bn.num_batches_tracked".format(tensor_name_prefix_torch, layer_idx): train_steps
+ })
+
+ for block_idx in range(len(self.layers_in_block)):
+ for layer_idx in range(self.layers_in_block[block_idx]):
+ for i in ["1", "2", "_sc"]:
+ map_dict_local.update({
+ "{}.block_{}.layer_{}.conv{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/conv{}/kernel".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": (3, 2, 0, 1),
+ },
+ "{}.block_{}.layer_{}.bn{}.weight".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/gamma".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.bias".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/beta".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.running_mean".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/moving_mean".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.running_var".format(tensor_name_prefix_torch, block_idx, layer_idx, i):
+ {"name": "{}/block_{}/layer_{}/bn{}/moving_variance".format(tensor_name_prefix_tf, block_idx, layer_idx, i),
+ "squeeze": None,
+ "transpose": None,
+ },
+ "{}.block_{}.layer_{}.bn{}.num_batches_tracked".format(tensor_name_prefix_torch, block_idx, layer_idx, i): train_steps,
+ })
+
+ return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+
+ map_dict = self.gen_tf2torch_map_dict()
+
+ var_dict_torch_update = dict()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ if name.startswith(self.tf2torch_tensor_name_prefix_torch):
+ if name in map_dict:
+ if "num_batches_tracked" not in name:
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), \
+ "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[name].size(), data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
+ name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
+ ))
+ else:
+ var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu")
+ logging.info("torch tensor: {}, manually assigning to: {}".format(
+ name, map_dict[name]
+ ))
+ else:
+ logging.warning("{} is missed from tf checkpoint".format(name))
+
+ return var_dict_torch_update
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..7a6425b 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -90,7 +90,9 @@
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
- dither: float = 1.0
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@@ -105,6 +107,8 @@
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
+ self.snip_edges = snip_edges
+ self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
- waveform = waveform * (1 << 15)
+ if self.upsacle_samples:
+ waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
- sample_frequency=self.fs)
+ sample_frequency=self.fs,
+ snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
diff --git a/funasr/models/pooling/statistic_pooling.py b/funasr/models/pooling/statistic_pooling.py
index eeaed7d..dc8c98f 100644
--- a/funasr/models/pooling/statistic_pooling.py
+++ b/funasr/models/pooling/statistic_pooling.py
@@ -2,7 +2,10 @@
from typing import Tuple
from typing import Union
from funasr.modules.nets_utils import make_non_pad_mask
+from torch.nn import functional as F
+import math
+VAR2STD_EPSILON = 1e-12
class StatisticPooling(torch.nn.Module):
def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
@@ -34,3 +37,59 @@
stat_pooling = torch.cat([mean, stddev], dim=1)
return stat_pooling
+
+    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
+        # Pooling has no trainable parameters, so there is nothing to convert.
+        return {}
+
+
+def statistic_pooling(
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor = None,
+        pooling_dim: Tuple = (2, 3)
+) -> torch.Tensor:
+    """Mean + stddev pooling over ``pooling_dim``, honoring time padding.
+
+    Args:
+        xs_pad: padded features, (Batch, Channel, Time, Frequency).
+        ilens: valid lengths along the time dim; ``None`` means no padding.
+        pooling_dim: dims reduced by the statistics (default: time & freq).
+
+    Returns:
+        Tensor of concatenated [mean, stddev] along dim 1.
+    """
+    # xs_pad in (Batch, Channel, Time, Frequency)
+
+    if ilens is None:
+        seq_mask = torch.ones_like(xs_pad).to(xs_pad)
+    else:
+        seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
+    # Masked mean: divide by the count of valid positions, not the padded size.
+    # NOTE(review): padded positions still contribute to the numerator sums;
+    # presumably the padding is zero upstream — confirm.
+    mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
+            torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+    squared_difference = torch.pow(xs_pad - mean, 2.0)
+    variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
+                torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+    # Squeeze the reduced dims from the highest index down so earlier
+    # indices stay valid.
+    for i in reversed(pooling_dim):
+        mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
+
+    # Floor tiny variances so the sqrt below is numerically safe.
+    value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
+    variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
+    stddev = torch.sqrt(variance)
+
+    stat_pooling = torch.cat([mean, stddev], dim=1)
+
+    return stat_pooling
+
+
+def windowed_statistic_pooling(
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor = None,
+        pooling_dim: Tuple = (2, 3),
+        pooling_size: int = 20,
+        pooling_stride: int = 1
+) -> Tuple[torch.Tensor, int]:
+    """Sliding-window statistic pooling along the time axis.
+
+    Applies :func:`statistic_pooling` to windows of ``pooling_size`` frames
+    taken every ``pooling_stride`` frames, producing one statistics vector
+    per window position.
+    """
+    # xs_pad in (Batch, Channel, Time, Frequency)
+
+    tt = xs_pad.shape[2]
+    num_chunk = int(math.ceil(tt / pooling_stride))
+    pad = pooling_size // 2
+    # Reflect-pad the time axis so edge windows are full-sized.
+    features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
+    stat_list = []
+
+    for i in range(num_chunk):
+        # B x C
+        st, ed = i*pooling_stride, i*pooling_stride+pooling_size
+        stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim)
+        stat_list.append(stat.unsqueeze(2))
+
+    # B x C x T
+    # NOTE(review): this raises TypeError when ilens is None, and the
+    # division yields a float tensor — confirm callers always pass ilens
+    # and expect non-integer lengths.
+    return torch.cat(stat_list, dim=2), ilens / pooling_stride
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c34759d..5615373 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -544,9 +544,8 @@
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
- def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
- target_label_length=None, token_num=None):
+
+ def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index e3ad56a..c47d96d 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -622,4 +622,108 @@
q_h, k_h, v_h = self.forward_qkv(x, memory)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- return self.forward_attention(v_h, scores, memory_mask)
\ No newline at end of file
+ return self.forward_attention(v_h, scores, memory_mask)
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Self-Attention layer with a fused Q/K/V projection.
+
+    Args:
+        n_head (int): The number of heads.
+        in_feat (int): Input feature dimension (projected to ``n_feat * 3``).
+        n_feat (int): The number of features (model dimension).
+        dropout_rate (float): Dropout rate on the attention weights.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        # Single projection producing query, key and value in one matmul.
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        # Last-computed attention map, kept for inspection/visualization.
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Project the input into per-head query, key and value tensors.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+
+        Returns:
+            torch.Tensor: Query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Un-reshaped value tensor (#batch, time, n_feat).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+            mask_att_chunk_encoder (torch.Tensor, optional): Extra chunk mask,
+                combined multiplicatively with ``mask``.
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            # Most negative representable value of the score dtype, so
+            # masked positions get ~zero probability after softmax.
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            # Zero masked entries again after softmax to remove residue.
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product self-attention.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
+                (#batch, time, time).
+            mask_att_chunk_encoder (torch.Tensor, optional): Extra chunk mask.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
diff --git a/funasr/modules/multi_layer_conv.py b/funasr/modules/multi_layer_conv.py
index 5fb0717..9d269ab 100644
--- a/funasr/modules/multi_layer_conv.py
+++ b/funasr/modules/multi_layer_conv.py
@@ -63,6 +63,58 @@
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+class FsmnFeedForward(torch.nn.Module):
+    """Position-wise feed forward for FSMN blocks.
+
+    This is a module of multi-layered conv1d designed
+    to replace position-wise feed-forward network
+    in FSMN block.
+    """
+
+    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
+        """Initialize FsmnFeedForward module.
+
+        Args:
+            in_chans (int): Number of input channels.
+            hidden_chans (int): Number of hidden channels.
+            out_chans (int): Number of output channels.
+            kernel_size (int): Kernel size of conv1d.
+            dropout_rate (float): Dropout rate.
+
+        """
+        super(FsmnFeedForward, self).__init__()
+        self.w_1 = torch.nn.Conv1d(
+            in_chans,
+            hidden_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        # Second conv has no bias; a LayerNorm sits between the two convs.
+        self.w_2 = torch.nn.Conv1d(
+            hidden_chans,
+            out_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            bias=False
+        )
+        self.norm = torch.nn.LayerNorm(hidden_chans)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, x, ilens=None):
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+            ilens (torch.Tensor, optional): Input lengths, passed through
+                unchanged.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Output tensors (B, T, out_chans)
+            and the unchanged ``ilens``.
+
+        """
+        # Conv1d operates on (B, C, T), hence the transposes around each conv.
+        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
+
+
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 7899400..02311fd 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -43,6 +43,7 @@
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
@@ -1272,6 +1273,52 @@
if args.dry_run:
pass
+ elif args.collect_stats:
+ # Perform on collect_stats mode. This mode has two roles
+ # - Derive the length and dimension of all input data
+ # - Accumulate feats, square values, and the length for whitening
+
+ if args.valid_batch_size is None:
+ args.valid_batch_size = args.batch_size
+
+ if len(args.train_shape_file) != 0:
+ train_key_file = args.train_shape_file[0]
+ else:
+ train_key_file = None
+ if len(args.valid_shape_file) != 0:
+ valid_key_file = args.valid_shape_file[0]
+ else:
+ valid_key_file = None
+
+ collect_stats(
+ model=model,
+ train_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+ key_file=train_key_file,
+ batch_size=args.batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ valid_iter=cls.build_streaming_iterator(
+ data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+ key_file=valid_key_file,
+ batch_size=args.valid_batch_size,
+ dtype=args.train_dtype,
+ num_workers=args.num_workers,
+ allow_variable_data_keys=args.allow_variable_data_keys,
+ ngpu=args.ngpu,
+ preprocess_fn=cls.build_preprocess_fn(args, train=False),
+ collate_fn=cls.build_collate_fn(args, train=False),
+ ),
+ output_dir=output_dir,
+ ngpu=args.ngpu,
+ log_interval=args.log_interval,
+ write_collected_feats=args.write_collected_feats,
+ )
else:
logging.info("Training args: {}".format(args))
# 6. Loads pre-trained model
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 1b7f152..e62a748 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -37,8 +37,9 @@
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
),
type_check=AbsESPnetModel,
default="asr",
@@ -177,6 +179,7 @@
fsmn_scama_opt=FsmnDecoderSCAMAOpt,
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -1098,5 +1101,8 @@
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
new file mode 100644
index 0000000..f3212f1
--- /dev/null
+++ b/funasr/tasks/diar.py
@@ -0,0 +1,585 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.layers.label_aggregation import LabelAggregate
+from funasr.models.ctc import CTC
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+# Registries of selectable component classes for the diarization task.
+# Each ClassChoices exposes --<name> / --<name>_conf command-line options.
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+label_aggregator_choices = ClassChoices(
+    "label_aggregator",
+    classes=dict(
+        label_aggregator=LabelAggregate
+    ),
+    type_check=torch.nn.Module,
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        sond=DiarSondModel,
+    ),
+    type_check=AbsESPnetModel,
+    default="sond",
+)
+# Speech encoder choices; defaults to the ResNet34 diarization encoder.
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        resnet34=ResNet34Diar,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="resnet34",
+)
+speaker_encoder_choices = ClassChoices(
+    "speaker_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default=None,
+    optional=True
+)
+# Context-dependent scorer (attention-based).
+cd_scorer_choices = ClassChoices(
+    "cd_scorer",
+    classes=dict(
+        san=SelfAttentionEncoder,
+    ),
+    type_check=AbsEncoder,
+    default=None,
+    optional=True,
+)
+# Context-independent scorer (similarity-based).
+ci_scorer_choices = ClassChoices(
+    "ci_scorer",
+    classes=dict(
+        dot=DotScorer,
+        cosine=CosScorer,
+    ),
+    type_check=torch.nn.Module,
+    default=None,
+    optional=True,
+)
+# decoder is used for output (e.g. post_net in SOND)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        rnn=RNNEncoder,
+        fsmn=FsmnEncoder,
+    ),
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
+
+
+class DiarTask(AbsTask):
+    """Speaker diarization task (SOND): argument registration, data
+    specification and model construction for training and inference."""
+
+    # If you need more than 1 optimizer, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --speaker_encoder and --speaker_encoder_conf
+        speaker_encoder_choices,
+        # --cd_scorer and cd_scorer_conf
+        cd_scorer_choices,
+        # --ci_scorer and ci_scorer_conf
+        ci_scorer_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        """Register diarization-task and preprocessing CLI arguments,
+        plus the --<name>/--<name>_conf pairs of every ClassChoices."""
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="char",
+            choices=["char"],
+            help="The text will be tokenized in the specified level token",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            # Fix: help text read "THe probability".
+            help="The probability for applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            # Fix: help text was copy-pasted from --noise_scp.
+            help="The file path of cmvn file.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of applying noise adding.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        """Return the batch collator: pads floats with 0.0 and ints with -1."""
+        assert check_argument_types()
+        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        """Build the per-utterance preprocessor, or None when disabled.
+
+        All optional augmentation attributes are read defensively with
+        hasattr() for backward compatibility with older config files.
+        """
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=None,
+                non_linguistic_symbols=None,
+                text_cleaner=None,
+                g2p_type=None,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                # Fix: the guard previously checked hasattr(args, "rir_scp"),
+                # which would raise AttributeError on configs that have
+                # rir_scp but lack speech_volume_normalize.
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "speech_volume_normalize")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Mandatory data keys: labels are required only for training."""
+        if not inference:
+            retval = ("speech", "profile", "label")
+        else:
+            # Recognition mode
+            retval = ("speech", "profile")
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """No optional data keys for this task."""
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        """Assemble the diarization model from the configured components.
+
+        Args:
+            args: parsed task arguments (see ``add_task_arguments``).
+
+        Returns:
+            The constructed model instance (an ``AbsESPnetModel``).
+
+        Raises:
+            RuntimeError: if ``args.token_list`` is neither str nor list.
+        """
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            # wav_frontend additionally needs the CMVN statistics file.
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 5. speaker encoder
+        if getattr(args, "speaker_encoder", None) is not None:
+            speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
+            speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
+        else:
+            speaker_encoder = None
+
+        # 6. CI & CD scorer
+        if getattr(args, "ci_scorer", None) is not None:
+            ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
+            ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
+        else:
+            ci_scorer = None
+
+        if getattr(args, "cd_scorer", None) is not None:
+            cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
+            cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
+        else:
+            cd_scorer = None
+
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(**args.decoder_conf)
+
+        # 8. Label aggregator
+        if getattr(args, "label_aggregator", None) is not None:
+            label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
+            label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
+        else:
+            label_aggregator = None
+
+        # 9. Build model
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            label_aggregator=label_aggregator,
+            encoder=encoder,
+            speaker_encoder=speaker_encoder,
+            ci_scorer=ci_scorer,
+            cd_scorer=cd_scorer,
+            decoder=decoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+    @classmethod
+    def build_model_from_file(
+        cls,
+        config_file: Union[Path, str] = None,
+        model_file: Union[Path, str] = None,
+        cmvn_file: Union[Path, str] = None,
+        device: str = "cpu",
+    ):
+        """Build model from the files.
+
+        This method is used for inference or fine-tuning.
+
+        Args:
+            config_file: The yaml file saved when training.
+            model_file: The model file saved when training.
+            cmvn_file: The cmvn file for front-end
+            device: Device type, "cpu", "cuda", or "cuda:N".
+
+        Returns:
+            Tuple of (model moved to ``device``, parsed config namespace).
+        """
+        assert check_argument_types()
+        if config_file is None:
+            assert model_file is not None, (
+                "The argument 'model_file' must be provided "
+                "if the argument 'config_file' is not specified."
+            )
+            # Fall back to the config next to the model file.
+            config_file = Path(model_file).parent / "config.yaml"
+        else:
+            config_file = Path(config_file)
+
+        with config_file.open("r", encoding="utf-8") as f:
+            args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
+        args = argparse.Namespace(**args)
+        model = cls.build_model(args)
+        if not isinstance(model, AbsESPnetModel):
+            raise RuntimeError(
+                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+            )
+        model.to(device)
+        model_dict = dict()
+        model_name_pth = None
+        if model_file is not None:
+            logging.info("model_file is {}".format(model_file))
+            if device == "cuda":
+                device = f"cuda:{torch.cuda.current_device()}"
+            model_dir = os.path.dirname(model_file)
+            model_name = os.path.basename(model_file)
+            # TF checkpoints ("model.ckpt-*") and ".bin" files are converted
+            # to torch state dicts once and cached beside the original.
+            if "model.ckpt-" in model_name or ".bin" in model_name:
+                if ".bin" in model_name:
+                    # NOTE(review): the cache for ".bin" is written with a
+                    # ".pb" extension although it holds a torch state dict —
+                    # confirm this naming is intentional.
+                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
+                else:
+                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+                if os.path.exists(model_name_pth):
+                    logging.info("model_file is load from pth: {}".format(model_name_pth))
+                    model_dict = torch.load(model_name_pth, map_location=device)
+                else:
+                    model_dict = cls.convert_tf2torch(model, model_file)
+                model.load_state_dict(model_dict)
+            else:
+                model_dict = torch.load(model_file, map_location=device)
+                model.load_state_dict(model_dict)
+        # Persist the converted weights so the next load skips conversion.
+        if model_name_pth is not None and not os.path.exists(model_name_pth):
+            torch.save(model_dict, model_name_pth)
+            logging.info("model_file is saved to pth: {}".format(model_name_pth))
+
+        return model, args
+
+    @classmethod
+    def convert_tf2torch(
+        cls,
+        model,
+        ckpt,
+    ):
+        """Convert a TF checkpoint into a torch state dict by delegating to
+        each sub-module's own ``convert_tf2torch`` implementation.
+
+        Args:
+            model: the assembled torch model (provides target state_dict).
+            ckpt: path to the TF checkpoint.
+
+        Returns:
+            dict: torch parameter name -> converted tensor.
+        """
+        logging.info("start convert tf model to torch model")
+        # Imported lazily to avoid a TF dependency unless conversion is used.
+        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
+        var_dict_tf = load_tf_dict(ckpt)
+        var_dict_torch = model.state_dict()
+        var_dict_torch_update = dict()
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # speaker encoder
+        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # cd scorer
+        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # ci scorer
+        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 46b9fe0..608c1d3 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -58,7 +58,7 @@
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
- required += ["token_list"]
+ # required += ["token_list"]
group.add_argument(
"--token_list",
diff --git a/funasr/utils/job_runner.py b/funasr/utils/job_runner.py
new file mode 100644
index 0000000..a35d49c
--- /dev/null
+++ b/funasr/utils/job_runner.py
@@ -0,0 +1,103 @@
+from __future__ import print_function
+from multiprocessing import Pool
+import argparse
+from tqdm import tqdm
+import math
+
+
class MultiProcessRunner:
    """Map a function over a task list with a multiprocessing pool.

    Subclasses implement ``prepare`` (build the task list from CLI args)
    and ``post`` (consume the results). ``--debug`` processes only the
    first task in-process for easier debugging.
    """

    def __init__(self, fn):
        # fn: picklable callable applied to each task.
        self.args = None
        self.process = fn

    def run(self):
        """Parse CLI options, build tasks, run the pool, hand off results."""
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        # BUG FIX: was action="store_ture" (typo), which makes argparse raise
        # ValueError("unknown action") as soon as run() is called.
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        result_list = self.pool_run(task_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        # Subclasses: parse args and return (task_list, args).
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        # Subclasses: consume the list of per-task results.
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        """Apply ``self.process`` to every task; sequential when args.debug."""
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()
            # FIX: wait for workers to terminate instead of returning with the
            # pool only closed (close() alone does not reap the processes).
            pool.join()

        return results
+
+
class MultiProcessRunnerV2:
    """Pool runner whose tasks are pre-split into one chunk per worker.

    ``prepare`` returns the flat task list; ``run`` slices it into ``--nj``
    contiguous chunks and maps ``self.process`` over the chunks.
    """

    def __init__(self, fn):
        # fn: picklable callable applied to one chunk of tasks.
        self.args = None
        self.process = fn

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)

        task_list, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [
            task_list[idx * chunk_size: (idx + 1) * chunk_size]
            for idx in range(args.nj)
        ]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)

    def prepare(self, parser):
        # Subclasses: parse args and return (task_list, args).
        raise NotImplementedError("Please implement the prepare function.")

    def post(self, result_list, args):
        # Subclasses: consume the per-chunk results.
        raise NotImplementedError("Please implement the post function.")

    def pool_run(self, tasks, args):
        # Debug path: handle just the first chunk in this process.
        if args.debug:
            return [self.process(tasks[0])]
        pool = Pool(args.nj)
        progress = tqdm(pool.imap(self.process, tasks),
                        total=len(tasks), ascii=True, disable=args.no_pbar)
        results = [item for item in progress]
        pool.close()

        return results
+
+
class MultiProcessRunnerV3(MultiProcessRunnerV2):
    """V2 variant whose ``prepare`` additionally returns shared parameters.

    Each worker receives a tuple ``(worker_id, chunk, shared_param, args)``
    instead of a bare chunk; ``pool_run`` is inherited from V2.
    """

    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)
        parser.add_argument("--sr", type=int, default=16000)

        task_list, shared_param, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = []
        for worker_id in range(args.nj):
            chunk = task_list[worker_id * chunk_size: (worker_id + 1) * chunk_size]
            subtask_list.append((worker_id, chunk, shared_param, args))
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
new file mode 100644
index 0000000..f27a63c
--- /dev/null
+++ b/funasr/utils/misc.py
@@ -0,0 +1,48 @@
+import io
+from collections import OrderedDict
+import numpy as np
+
+
def statistic_model_parameters(model, prefix=None):
    """Count the parameters in ``model.state_dict()``.

    BatchNorm bookkeeping entries ("num_batches_tracked") are skipped.
    When ``prefix`` is given, only keys starting with it are counted.
    """
    state = model.state_dict()
    candidate_keys = sorted(k for k in state.keys() if "num_batches_tracked" not in k)
    total = 0
    for key in candidate_keys:
        if prefix is None or key.startswith(prefix):
            total += state[key].numel()
    return total
+
+
def int2vec(x, vec_dim=8, dtype=int):
    """Encode non-negative integer ``x`` as a binary vector of length ``vec_dim``.

    The vector is little-endian: the lowest bit comes first.

    BUG FIX: the default dtype was ``np.int``, which was removed in
    NumPy 1.24 and makes this module fail at import time on modern NumPy.
    ``np.int`` was a plain alias for the builtin ``int``, so ``dtype=int``
    is behavior-identical.
    """
    bits = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(bits)[::-1]) == '1').astype(dtype)
+
+
def seq2arr(seq, vec_dim=8):
    """Stack the little-endian bit vectors of each element of ``seq``
    into a 2-D array of shape (len(seq), vec_dim)."""
    rows = [int2vec(int(item), vec_dim) for item in seq]
    return np.row_stack(rows)
+
+
def load_scp_as_dict(scp_path, value_type='str', kv_sep=" "):
    """Read a Kaldi-style scp file into an OrderedDict of key -> value.

    Each line is split at the first occurrence of ``kv_sep``; with
    ``value_type='list'`` the value is further split on single spaces.
    """
    result = OrderedDict()
    with io.open(scp_path, 'r', encoding='utf-8') as fin:
        for raw_line in fin.readlines():
            stripped = raw_line.strip()
            sep_at = stripped.find(kv_sep)
            key = stripped[:sep_at]
            value = stripped[sep_at + 1:]
            if value_type == 'list':
                value = value.split(' ')
            result[key] = value
    return result
+
+
def load_scp_as_list(scp_path, value_type='str', kv_sep=" "):
    """Read a Kaldi-style scp file into a list of (key, value) pairs.

    Like ``load_scp_as_dict`` but preserves duplicates: each line yields one
    (key, value) tuple; with ``value_type='list'`` the value is split on
    single spaces.
    """
    pairs = []
    with io.open(scp_path, 'r', encoding='utf8') as fin:
        for raw_line in fin.readlines():
            stripped = raw_line.strip()
            sep_at = stripped.find(kv_sep)
            key, value = stripped[:sep_at], stripped[sep_at + 1:]
            pairs.append((key, value.split(' ') if value_type == 'list' else value))
    return pairs
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index 4da0d59..575fb90 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -232,5 +232,9 @@
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != ' ':
+ real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
- return sentence
+ return sentence, real_word_lists
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 3afaa40..33d1255 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -86,14 +86,51 @@
else:
return time_stamp_list
-
-def time_stamp_lfr6_advance(tst: List, text: str):
- # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst
- if text.endswith('</s>'):
- text = text[:-4]
def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
    """Derive per-token timestamps from 3x-upsampled CIF peaks (BiCIF_Paraformer).

    Args:
        us_alphas: upsampled CIF alphas; only used to detect a batch dim
            (rank 3 means batched; inference batch_size=1 only).
        us_cif_peak: upsampled CIF peak tensor; values >= 1.0 - 1e-4 mark fires.
        char_list: decoded tokens; a trailing '</s>' is dropped. The list is
            NOT mutated (bug fix: the original inserted '<sil>' markers into
            the caller's list in place).
        begin_time: VAD segment offset in milliseconds added to every timestamp.
        end_time: unused; kept for interface compatibility.

    Returns:
        List of [start_ms, end_ms] int pairs, one per non-silence token.
    """
    START_END_THRESHOLD = 5          # frames of slack that count as lead/tail silence
    TIME_RATE = 10.0 * 6 / 1000 / 3  # seconds per frame, 3 times upsampled
    if len(us_alphas.shape) == 3:
        cif_peak = us_cif_peak[0]    # support inference batch_size=1 only
    else:
        cif_peak = us_cif_peak
    num_frames = cif_peak.shape[0]
    # work on a copy so the caller's token list is left untouched
    char_list = list(char_list)
    if char_list and char_list[-1] == '</s>':
        char_list = char_list[:-1]
    timestamp_list = []
    # for the bicif model trained with large data, cif2 actually fires when a
    # character starts, so the frames between two peaks are the duration of the
    # former token; the peak index is shifted back slightly (-1.5 frames)
    fire_place = torch.where(cif_peak > 1.0 - 1e-4)[0].cpu().numpy() - 1.5
    num_peak = len(fire_place)
    # number of peaks is supposed to be number of tokens + 1
    assert num_peak == len(char_list) + 1
    # leading silence
    if fire_place[0] > START_END_THRESHOLD:
        char_list.insert(0, '<sil>')
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
    # token intervals: each token spans from its own peak to the next one
    # (the peak is always a little ahead of the start time)
    for i in range(len(fire_place) - 1):
        timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
    # tail token and trailing silence: split the 0-weight tail between them
    if num_frames - fire_place[-1] > START_END_THRESHOLD:
        _end = (num_frames + fire_place[-1]) / 2
        timestamp_list[-1][1] = _end * TIME_RATE
        timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
        char_list.append("<sil>")
    else:
        timestamp_list[-1][1] = num_frames * TIME_RATE
    if begin_time:  # add offset time in model with vad (ms -> s)
        for ts in timestamp_list:
            ts[0] = ts[0] + begin_time / 1000.0
            ts[1] = ts[1] + begin_time / 1000.0
    # convert to integer milliseconds, dropping silence markers
    # (removed: a res_txt debug string was built here and never used)
    res = []
    for char, timestamp in zip(char_list, timestamp_list):
        if char != '<sil>':
            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
    return res
+
diff --git a/funasr/version.txt b/funasr/version.txt
index 1180819..0ea3a94 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.1.7
+0.2.0
diff --git a/setup.py b/setup.py
index dcaeec0..67206f5 100644
--- a/setup.py
+++ b/setup.py
@@ -12,11 +12,11 @@
requirements = {
"install": [
"setuptools>=38.5.1",
- "configargparse>=1.2.1",
+ # "configargparse>=1.2.1",
"typeguard>=2.7.0",
"humanfriendly",
"scipy>=1.4.1",
- "filelock",
+ # "filelock",
"librosa>=0.8.0",
"jamo==0.4.1", # For kss
"PyYAML>=5.1.2",
@@ -27,13 +27,13 @@
"nltk>=3.4.5",
# ASR
"sentencepiece",
- "ctc-segmentation<1.8,>=1.6.6",
+ # "ctc-segmentation<1.8,>=1.6.6",
# TTS
- "pyworld>=0.2.10",
+ # "pyworld>=0.2.10",
"pypinyin<=0.44.0",
"espnet_tts_frontend",
# ENH
- "ci_sdr",
+ # "ci_sdr",
"pytorch_wpe",
"editdistance==0.5.2",
"tensorboard==1.15",
@@ -43,7 +43,7 @@
],
# train: The modules invoked when training only.
"train": [
- "pillow>=6.1.0",
+ # "pillow>=6.1.0",
"editdistance==0.5.2",
"wandb",
],
@@ -51,18 +51,18 @@
# but are invoked for the python scripts in each recipe
"recipe": [
"espnet_model_zoo",
- "gdown",
- "resampy",
- "pysptk>=0.1.17",
- "morfessor", # for zeroth-korean
- "youtube_dl", # for laborotv
- "nnmnkwii",
- "museval>=0.2.1",
- "pystoi>=0.2.2",
- "mir-eval>=0.6",
- "fastdtw",
- "nara_wpe>=0.0.5",
- "sacrebleu>=1.5.1",
+ # "gdown",
+ # "resampy",
+ # "pysptk>=0.1.17",
+ # "morfessor", # for zeroth-korean
+ # "youtube_dl", # for laborotv
+ # "nnmnkwii",
+ # "museval>=0.2.1",
+ # "pystoi>=0.2.2",
+ # "mir-eval>=0.6",
+ # "fastdtw",
+ # "nara_wpe>=0.0.5",
+ # "sacrebleu>=1.5.1",
],
# all: The modules should be optionally installled due to some reason.
# Please consider moving them to "install" occasionally
@@ -72,7 +72,7 @@
"torch_optimizer",
"fairscale",
"transformers",
- "gtn==0.0.0",
+ # "gtn==0.0.0",
],
"setup": [
"numpy<=1.21.3",
--
Gitblit v1.9.1