From 97a689d65da434345a641a909f13b78e5690c86b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 18 May 2023 19:35:08 +0800
Subject: [PATCH] Merge pull request #526 from alibaba-damo-academy/dev_infer
---
egs/aishell/data2vec_transformer_finetune/run.sh | 123
egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml | 6
egs/librispeech_100h/conformer/local/spm_train.py | 12
egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/demo.py | 0
funasr/datasets/small_datasets/length_batch_sampler.py | 147
funasr/models/e2e_asr.py | 13
funasr/models/transformer_lm.py | 2
egs/aishell2/transformer/utils/fix_data.sh | 4
funasr/datasets/small_datasets/preprocessor.py | 875 +++
egs/librispeech_100h/conformer/utils | 1
funasr/bin/lm_inference_launch.py | 299 +
funasr/build_utils/build_trainer.py | 820 +++
egs/aishell2/conformer/local/prepare_data.sh | 6
egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/demo.py | 0
egs/aishell2/transformer/utils/compute_cmvn.py | 88
funasr/layers/abs_normalize.py | 2
funasr/tasks/vad.py | 79
egs/aishell2/paraformerbert/local/extract_embeds.sh | 33
funasr/tasks/sa_asr.py | 4
egs/aishell/conformer/run.sh | 118
funasr/tasks/punctuation.py | 2
egs/aishell/transformer/utils/compute_cmvn.sh | 21
egs/librispeech/conformer/local/spm_train.py | 12
egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh | 105
egs/aishell2/transformer/utils/fix_data_feat.sh | 8
funasr/models/encoder/abs_encoder.py | 2
funasr/build_utils/__init__.py | 0
funasr/main_funcs/collect_stats.py | 4
funasr/bin/sa_asr_train.py | 3
funasr/models/encoder/resnet34_encoder.py | 2
funasr/bin/sv_inference_launch.py | 180
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py | 1
funasr/models/encoder/conformer_encoder.py | 4
egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml | 70
funasr/build_utils/build_scheduler.py | 44
egs/aishell/paraformerbert/local/extract_embeds.sh | 14
funasr/models/encoder/sanm_encoder.py | 14
funasr/tasks/diar.py | 12
egs/aishell/paraformer/local/download_and_untar.sh | 105
funasr/tasks/lm.py | 10
egs/aishell/conformer/conf/train_asr_conformer.yaml | 32
funasr/build_utils/build_distributed.py | 38
funasr/bin/punc_inference_launch.py | 190
funasr/tasks/abs_task.py | 56
egs/aishell2/paraformerbert/local/prepare_data.sh | 7
funasr/bin/asr_infer.py | 1834 ++++++
egs/aishell/data2vec_paraformer_finetune/run.sh | 125
funasr/models/encoder/data2vec_encoder.py | 2
funasr/models/e2e_vad.py | 8
egs/librispeech/conformer/conf/train_asr_conformer.yaml | 14
egs/aishell2/paraformer/run.sh | 116
funasr/main_funcs/calculate_all_attentions.py | 4
egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py | 8
funasr/models/e2e_diar_eend_ola.py | 5
funasr/models/e2e_diar_sond.py | 8
egs/aishell2/transformer/conf/train_asr_transformer.yaml | 23
funasr/models/e2e_tp.py | 15
funasr/bin/punc_infer.py | 271 +
egs/aishell2/transformer/utils/compute_cmvn.sh | 21
funasr/bin/asr_train.py | 27
funasr/train/abs_model.py | 142
egs/aishell/transformer/path.sh | 2
funasr/models/frontend/windowing.py | 7
egs/aishell/transformer/utils/cmvn_converter.py | 6
funasr/build_utils/build_diar_model.py | 296 +
egs/aishell2/paraformer/local/prepare_data.sh | 6
funasr/bin/vad_infer.py | 201
egs/aishell2/paraformerbert/run.sh | 142
funasr/models/base_model.py | 17
egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml | 32
funasr/bin/punc_train.py | 4
funasr/datasets/small_datasets/sequence_iter_factory.py | 189
funasr/bin/train.py | 572 ++
docs/reference/build_task.md | 2
egs/aishell/conformer/local/download_and_untar.sh | 105
funasr/build_utils/build_vad_model.py | 77
egs/aishell2/data2vec_pretrain/run.sh | 87
egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml | 6
funasr/build_utils/build_optimizer.py | 28
funasr/models/e2e_asr_transducer.py | 6
funasr/models/encoder/rnn_encoder.py | 3
funasr/models/frontend/wav_frontend_kaldifeat.py | 119
funasr/build_utils/build_pretrain_model.py | 107
funasr/bin/lm_train.py | 3
funasr/bin/vad_inference_launch.py | 285 +
funasr/models/e2e_asr_paraformer.py | 215
egs/aishell2/transformer/utils/combine_cmvn_file.py | 27
egs/librispeech_100h/conformer/path.sh | 0
egs_modelscope/tp/TEMPLATE/infer.py | 3
funasr/bin/diar_inference_launch.py | 399 +
egs/aishell2/transformer/utils/compute_fbank.py | 24
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml | 1
egs/librispeech/conformer/run.sh | 163
egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh | 105
egs/librispeech_100h/conformer/run.sh | 219
funasr/export/test/test_onnx_punc_vadrealtime.py | 2
egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam60.yaml | 6
funasr/utils/prepare_data.py | 226
egs/aishell2/transformer/run.sh | 112
egs/librispeech_100h/conformer/local/spm_encode.py | 98
egs/aishell/paraformer/run.sh | 115
funasr/build_utils/build_dataloader.py | 15
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py | 19
egs/aishell/paraformerbert/run.sh | 131
egs/librispeech_100h/conformer/local/data_prep.sh | 0
egs/aishell/paraformerbert/local/download_and_untar.sh | 105
funasr/layers/global_mvn.py | 5
funasr/build_utils/build_args.py | 93
funasr/models/e2e_sv.py | 8
funasr/datasets/large_datasets/build_dataloader.py | 36
funasr/train/trainer.py | 4
funasr/datasets/small_datasets/dataset.py | 269 +
egs/librispeech/conformer/local/spm_encode.py | 98
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 5
funasr/bin/diar_infer.py | 347 +
funasr/models/frontend/abs_frontend.py | 2
funasr/build_utils/build_punc_model.py | 68
egs/aishell/transformer/run.sh | 127
egs/aishell/paraformerbert/local/aishell_data_prep.sh | 23
funasr/bin/tp_inference_launch.py | 189
docs/academic_recipe/asr_recipe.md | 103
egs/librispeech/conformer/local/download_and_untar.sh | 97
funasr/models/data2vec.py | 6
egs/aishell2/transformer/utils/download_model.py | 20
funasr/build_utils/build_lm_model.py | 57
funasr/build_utils/build_asr_model.py | 423 +
egs/aishell/transformer/utils/compute_cmvn.py | 88
funasr/bin/diar_train.py | 3
egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml | 32
egs/aishell2/transformer/utils/cmvn_converter.py | 51
egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml | 31
egs/aishell/transformer/conf/train_asr_transformer.yaml | 32
funasr/models/vad_realtime_transformer.py | 3
funasr/datasets/large_datasets/dataset.py | 24
egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml | 41
egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam20.yaml | 6
egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam1.yaml | 6
funasr/bin/sv_infer.py | 163
funasr/datasets/small_datasets/collate_fn.py | 93
funasr/datasets/large_datasets/utils/tokenize.py | 2
funasr/models/seq_rnn_lm.py | 3
egs/aishell/transformer/utils/combine_cmvn_file.py | 27
funasr/fileio/sound_scp.py | 15
funasr/version.txt | 2
funasr/models/e2e_asr_mfcca.py | 138
egs/librispeech_100h/conformer/local/download_and_untar.sh | 97
funasr/layers/inversible_interface.py | 2
funasr/models/e2e_sa_asr.py | 4
funasr/models/encoder/mfcca_encoder.py | 120
egs/aishell2/transformer/local/prepare_data.sh | 6
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py | 9
egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer_sv.py | 4
funasr/models/frontend/fused.py | 2
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py | 8
egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/demo.py | 0
funasr/models/target_delay_transformer.py | 5
egs/aishell2/transformer/utils/compute_fbank.sh | 5
funasr/models/frontend/s3prl.py | 1
egs/aishell2/transformer/utils/compute_wer.py | 4
funasr/models/e2e_uni_asr.py | 11
egs/aishell2/conformer/conf/train_asr_conformer.yaml | 23
egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml | 23
funasr/models/specaug/abs_specaug.py | 2
funasr/tasks/sv.py | 10
egs/aishell/transformer/local/download_and_untar.sh | 105
egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml | 29
funasr/tasks/asr.py | 368 -
egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml | 39
egs/aishell2/conformer/run.sh | 108
funasr/models/frontend/wav_frontend.py | 2
egs/librispeech/conformer/local/data_prep.sh | 0
/dev/null | 55
funasr/build_utils/build_model.py | 25
funasr/bin/tp_infer.py | 120
funasr/models/frontend/default.py | 4
funasr/models/predictor/cif.py | 5
funasr/bin/asr_inference_launch.py | 1724 ++++++
 177 files changed, 13713 insertions(+), 2425 deletions(-)
diff --git a/docs/academic_recipe/asr_recipe.md b/docs/academic_recipe/asr_recipe.md
index f82a6fe..128b85c 100644
--- a/docs/academic_recipe/asr_recipe.md
+++ b/docs/academic_recipe/asr_recipe.md
@@ -1,21 +1,35 @@
# Speech Recognition
-Here we take "Training a paraformer model from scratch using the AISHELL-1 dataset" as an example to introduce how to use FunASR. According to this example, users can similarly employ other datasets (such as AISHELL-2 dataset, etc.) to train other models (such as conformer, transformer, etc.).
+In FunASR, we provide several ASR benchmarks, such as AISHELL, LibriSpeech, and WenetSpeech, and support different model architectures, including conformer, paraformer, and uniasr.
-## Overall Introduction
+## Quick Start
+After downloading and installing FunASR, users can use our provided recipes to easily reproduce the relevant experimental results. Here we take "paraformer on AISHELL-1" as an example.
+
+First, change to the directory of the AISHELL-1 paraformer example.
+```sh
+cd egs/aishell/paraformer
+```
+Then you can directly start the recipe as follows:
+```sh
+conda activate funasr
+. ./run.sh
+```
+The training log files are saved in `exp/*_train_*/log/train.log.*` and the inference results are saved in `exp/*_train_*/decode_asr_*`.
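+
+For example, to follow the training progress of one worker (a sketch; we assume a rank suffix such as `0` on the log file name):
+```sh
+tail -f exp/*_train_*/log/train.log.0
+```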
+
+## Introduction
We provide a recipe `egs/aishell/paraformer/run.sh` for training a paraformer model on AISHELL-1 dataset. This recipe consists of five stages, supporting training on multiple GPUs and decoding by CPU or GPU. Before introducing each stage in detail, we first explain several parameters which should be set by users.
- `CUDA_VISIBLE_DEVICES`: visible gpu list
- `gpu_num`: the number of GPUs used for training
- `gpu_inference`: whether to use GPUs for decoding
- `njob`: for CPU decoding, indicating the total number of CPU jobs; for GPU decoding, indicating the number of jobs on each GPU
-- `data_aishell`: the raw path of AISHELL-1 dataset
+- `raw_data`: the path of the raw AISHELL-1 dataset
- `feats_dir`: the path for saving processed data
- `nj`: the number of jobs for data preparation
- `speed_perturb`: the speed perturbation factors
- `exp_dir`: the path for saving experimental results
- `tag`: the suffix of experimental result directory
-## Stage 0: Data preparation
-This stage processes raw AISHELL-1 dataset `$data_aishell` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`. `xxx` means `train/dev/test`. Here we assume users have already downloaded AISHELL-1 dataset. If not, users can download data [here](https://www.openslr.org/33/) and set the path for `$data_aishell`. The examples of `wav.scp` and `text` are as follows:
+### Stage 0: Data preparation
+This stage processes the raw AISHELL-1 dataset `$raw_data` and generates the corresponding `wav.scp` and `text` in `$feats_dir/data/xxx`, where `xxx` is one of `train/dev/test`. Here we assume users have already downloaded the AISHELL-1 dataset. If not, users can download it [here](https://www.openslr.org/33/) and set `$raw_data` to its path. Examples of `wav.scp` and `text` are as follows:
* `wav.scp`
```
BAC009S0002W0122 /nfs/ASR_DATA/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
@@ -32,30 +46,10 @@
```
These two files both have two columns: the first column is the wav ids and the second column is the corresponding wav paths/label tokens.
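+
+As a quick sanity check that the two files cover the same utterances, one can compare their key columns (a sketch; assumes a bash shell and that both files are sorted by key):
+```sh
+cd $feats_dir/data/train
+# print keys that differ between wav.scp and text; empty output means they match
+diff <(awk '{print $1}' wav.scp) <(awk '{print $1}' text)
+```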
-## Stage 1: Feature Generation
-This stage extracts FBank features from `wav.scp` and apply speed perturbation as data augmentation according to `speed_perturb`. Users can set `nj` to control the number of jobs for feature generation. The generated features are saved in `$feats_dir/dump/xxx/ark` and the corresponding `feats.scp` files are saved as `$feats_dir/dump/xxx/feats.scp`. An example of `feats.scp` can be seen as follows:
-* `feats.scp`
-```
-...
-BAC009S0002W0122_sp0.9 /nfs/funasr_data/aishell-1/dump/fbank/train/ark/feats.16.ark:592751055
-...
-```
-Note that samples in this file have already been shuffled randomly. This file contains two columns. The first column is wav ids while the second column is kaldi-ark feature paths. Besides, `speech_shape` and `text_shape` are also generated in this stage, denoting the speech feature shape and text length of each sample. The examples are shown as follows:
-* `speech_shape`
-```
-...
-BAC009S0002W0122_sp0.9 665,80
-...
-```
-* `text_shape`
-```
-...
-BAC009S0002W0122_sp0.9 15
-...
-```
-These two files have two columns. The first column is wav ids and the second column is the corresponding speech feature shape and text length.
+### Stage 1: Feature and CMVN Generation
+This stage computes the global CMVN statistics on the `train` set, which are used in the following stages. Users can set `nj` to control the number of jobs for computing CMVN. The generated CMVN file is saved as `$feats_dir/data/train/cmvn/cmvn.mvn`.
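+
+To (re-)run only this stage (a sketch, using the recipe's own stage flags):
+```sh
+. ./run.sh --stage 1 --stop_stage 1 --nj 64
+```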
-## Stage 2: Dictionary Preparation
+### Stage 2: Dictionary Preparation
This stage processes the dictionary, which is used as a mapping between label characters and integer indices during ASR training. The processed dictionary file is saved as `$feats_dir/data/${lang}_token_list/$token_type/tokens.txt`. An example of `tokens.txt` is as follows:
* `tokens.txt`
```
@@ -74,7 +68,9 @@
* `</s>`: indicates the end-of-sentence token
* `<unk>`: indicates the out-of-vocabulary token
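+
+For example, to check the resulting vocabulary size (a sketch; the path follows the recipe defaults):
+```sh
+wc -l ${feats_dir}/data/${lang}_token_list/char/tokens.txt
+```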
-## Stage 3: Training
+### Stage 3: LM Training
+
+### Stage 4: ASR Training
This stage performs the training of the specified model. To start training, users should manually set `exp_dir`, `CUDA_VISIBLE_DEVICES` and `gpu_num`, which have already been explained above. By default, the best `$keep_nbest_models` checkpoints on the validation dataset will be averaged to generate a better model, which is then adopted for decoding.
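+
+For example, to run only the ASR training stage on 4 GPUs (a sketch):
+```sh
+. ./run.sh --stage 4 --stop_stage 4 --CUDA_VISIBLE_DEVICES "0,1,2,3" --gpu_num 4
+```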
* DDP Training
@@ -100,7 +96,7 @@
tensorboard --logdir ${exp_dir}/exp/${model_dir}/tensorboard/train
```
-## Stage 4: Decoding
+### Stage 5: Decoding
This stage generates the recognition results and calculates the `CER` to verify the performance of the trained model.
* Mode Selection
@@ -117,7 +113,7 @@
* Performance
-We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` result. The following is an example of `text.cer`:
+We adopt `CER` to verify the performance. The results are in `$exp_dir/exp/$model_dir/$decoding_yaml_name/$average_model_name/$dset`, namely `text.cer` and `text.cer.txt`. `text.cer` saves the comparison between the recognized text and the reference text while `text.cer.txt` saves the final `CER` results. The following is an example of `text.cer`:
* `text.cer`
```
...
@@ -127,3 +123,48 @@
...
```
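+
+To print the final `CER` of each test set after decoding (a sketch; the directory names below follow the recipe defaults):
+```sh
+for dset in dev test; do
+    cat ${exp_dir}/exp/${model_dir}/decode_asr_transformer/valid.acc.ave_10best/${dset}/text.cer.txt
+done
+```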
+## Change settings
+Here we explain several common custom settings, which can help users modify the scripts according to their own needs.
+
+* Training with specified GPUs
+
+For example, if users want to use 2 GPUs with ids `2` and `3`, they can run the following command:
+```sh
+. ./run.sh --CUDA_VISIBLE_DEVICES "2,3" --gpu_num 2
+```
+
+* Start from/Stop at a specified stage
+
+The recipe includes several stages. Users can start from or stop at any stage. For example, the following command starts from stage 3 and stops at stage 5:
+```sh
+. ./run.sh --stage 3 --stop_stage 5
+```
+
+* Change the configuration of the model
+
+The configuration of the model is set in the config file `conf/train_*.yaml`. Specifically, the default encoder configuration of paraformer is as follows:
+```
+encoder: conformer
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4 # number of heads in multi-head attention
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder input layer architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+```
+Users can change the encoder configuration by modifying these values. For example, if users want an encoder with 16 conformer blocks, each with 8 attention heads, they just need to change `num_blocks` from 12 to 16 and `attention_heads` from 4 to 8. Besides, the batch size, learning rate and other training hyper-parameters are also set in this config file. To change these hyper-parameters, users just need to directly change the corresponding values. For example, the default learning rate is `0.0005`; to change it to 0.0002, set `lr: 0.0002`.
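+
+For example, a minimal sketch that changes the learning rate in place (assuming the config still contains the default `lr: 0.0005` on its own line):
+```sh
+sed -i 's/^lr: 0.0005/lr: 0.0002/' conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+```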
+
+* Decoding by CPU or GPU
+
+We support both CPU and GPU decoding. For CPU decoding, set `gpu_inference=false` and use `njob` to specify the total number of CPU jobs. For GPU decoding, first set `gpu_inference=true`, then set `gpuid_list` to specify which GPUs to use for decoding and `njob` to specify the number of decoding jobs on each GPU.
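+
+For example, GPU decoding with 2 GPUs and 4 jobs per GPU, i.e. 8 decoding jobs in total (a sketch):
+```sh
+. ./run.sh --stage 5 --stop_stage 5 --gpu_inference true --gpuid_list "0,1" --njob 4
+```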
\ No newline at end of file
diff --git a/docs/reference/build_task.md b/docs/reference/build_task.md
index be2d1af..2020860 100644
--- a/docs/reference/build_task.md
+++ b/docs/reference/build_task.md
@@ -103,7 +103,7 @@
)
return model
```
-This function defines the detail of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit `AbsESPnetModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function needed to be implemented is the `forward` function.
+This function defines the details of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared; the remaining work is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including the encoder, decoder, etc., and then combines these modules together to generate a complete model. In FunASR, the model needs to inherit `FunASRModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function that needs to be implemented is the `forward` function.
Next, we take `SANMEncoder` as an example to introduce how to use a custom encoder as part of the model; the corresponding code can be seen in `funasr/models/encoder/sanm_encoder.py`. For a custom encoder, in addition to inheriting the common encoder class `AbsEncoder`, it is also necessary to define the `forward` function to achieve the forward computation of the encoder. After defining the encoder, it should also be registered in the `Task`. The corresponding code example can be seen below:
```python
diff --git a/egs/aishell/conformer/conf/train_asr_conformer.yaml b/egs/aishell/conformer/conf/train_asr_conformer.yaml
index ddf217e..4814ee7 100644
--- a/egs/aishell/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell/conformer/conf/train_asr_conformer.yaml
@@ -29,21 +29,27 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -76,5 +82,17 @@
- 40
num_time_mask: 2
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
+
log_interval: 50
-normalize: None
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/conformer/local/download_and_untar.sh b/egs/aishell/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/conformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/conformer/local/prepare_data.sh b/egs/aishell/conformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/conformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/conformer/run.sh b/egs/aishell/conformer/run.sh
index 227b3f2..3c05006 100755
--- a/egs/aishell/conformer/run.sh
+++ b/egs/aishell/conformer/run.sh
@@ -16,22 +16,20 @@
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
@@ -49,7 +47,7 @@
test_sets="dev test"
asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,28 +93,26 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -161,26 +126,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -191,8 +153,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -203,7 +165,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +186,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,5 +207,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
index f9a2cdb..1e1acee 100644
--- a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
@@ -30,7 +30,6 @@
require_same_masks: true
mask_dropout: 0
-
# decoder related
decoder: paraformer_decoder_san
decoder_conf:
@@ -42,6 +41,18 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model: paraformer
model_conf:
ctc_weight: 0.3
@@ -50,15 +61,10 @@
predictor_weight: 1.0
sampling_ratio: 0.4
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -98,6 +104,17 @@
l_order: 1
r_order: 1
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
log_interval: 50
unused_parameters: true
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh
index d033ce2..147add1 100755
--- a/egs/aishell/data2vec_paraformer_finetune/run.sh
+++ b/egs/aishell/data2vec_paraformer_finetune/run.sh
@@ -8,33 +8,31 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -69,10 +67,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,33 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +130,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
--init_param ${init_param} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -196,8 +157,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -208,7 +169,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +190,8 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +212,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
index 5bc5236..32a7b5b 100644
--- a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
@@ -30,25 +30,28 @@
require_same_masks: true
mask_dropout: 0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 1.0
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# for logger
-log_interval: 50
-
-# minibatch related
-batch_type: length
-batch_bins: 16000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -57,8 +60,6 @@
- cer_ctc
- min
keep_nbest_models: 10
-unused_parameters: true
-normalize: None
# NoamLR is deprecated. Use WarmupLR.
# The following is equivalent setting for NoamLR:
@@ -92,4 +93,20 @@
time_mask_width_range:
- 0
- 40
- num_time_mask: 2
\ No newline at end of file
+ num_time_mask: 2
+
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
+
+log_interval: 50
+unused_parameters: true
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh
index 26222e6..af0b8c1 100755
--- a/egs/aishell/data2vec_transformer_finetune/run.sh
+++ b/egs/aishell/data2vec_transformer_finetune/run.sh
@@ -13,28 +13,26 @@
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.cer_ctc.ave_10best.pb
@@ -69,10 +67,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
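+# Stage map: -1 download, 0 data prep, 1 CMVN, 2 dictionary, 3 LM training (placeholder), 4 ASR training, 5 inference.
+# Assumed full run: bash run.sh --stage -1 --stop_stage 5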
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,33 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +130,24 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
--init_param ${init_param} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -196,8 +158,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -208,7 +170,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +191,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +212,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
index 24b2620..9dd3fb3 100644
--- a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
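+  # lfr_m/lfr_n: low-frame-rate stacking/subsampling factors (assumed); 1/1 leaves LFR off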
+
model: paraformer
model_conf:
ctc_weight: 0.3
@@ -36,16 +47,12 @@
length_normalized_loss: false
predictor_weight: 1.0
sampling_ratio: 0.4
-
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
+ use_1st_decoder_loss: true
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -78,7 +85,7 @@
- 40
num_time_mask: 2
-predictor: cif_predictor_v2
+predictor: cif_predictor
predictor_conf:
idim: 256
threshold: 1.0
@@ -86,6 +93,17 @@
r_order: 1
tail_threshold: 0.45
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
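+  # with batch_type "token", batch_size is a per-batch token/frame budget rather than an
+  # utterance count (assumption from the name); it supersedes the batch_bins setup removed above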
log_interval: 50
normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformer/local/download_and_untar.sh b/egs/aishell/paraformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/paraformer/local/prepare_data.sh b/egs/aishell/paraformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/paraformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/paraformer/run.sh b/egs/aishell/paraformer/run.sh
index 53b5f90..39ce85a 100755
--- a/egs/aishell/paraformer/run.sh
+++ b/egs/aishell/paraformer/run.sh
@@ -16,25 +16,23 @@
feats_dir="../DATA" #feature output dictionary
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -49,7 +47,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
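+# stage -1 is safe to re-run: download_and_untar.sh leaves a .complete marker and exits early next time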
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -131,23 +98,21 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -161,26 +126,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -191,8 +153,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -203,7 +165,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +186,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
index f51a2ea..f2652e8 100644
--- a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer_bert
model_conf:
@@ -41,19 +52,10 @@
embed_dims: 768
embeds_loss_weight: 2.0
-
-
-# minibatch related
-#batch_type: length
-#batch_bins: 40000
-batch_type: numel
-batch_bins: 2000000
-num_workers: 16
-
# optimization related
-accum_grad: 4
+accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -92,8 +94,19 @@
threshold: 1.0
l_order: 1
r_order: 1
+ tail_threshold: 0.45
+dataset_conf:
+ data_names: speech,text,embed
+ data_types: sound,text,kaldi_ark
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
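+  # the extra "embed" stream reads per-utterance BERT embeddings as kaldi_ark matrices,
+  # generated beforehand by local/extract_embeds.sh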
log_interval: 50
-normalize: None
-allow_variable_data_keys: true
\ No newline at end of file
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformerbert/local/aishell_data_prep.sh b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
index b6ea36b..83f489b 100755
--- a/egs/aishell/paraformerbert/local/aishell_data_prep.sh
+++ b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
@@ -5,19 +5,20 @@
#. ./path.sh || exit 1;
-if [ $# != 2 ]; then
- echo "Usage: $0 <audio-path> <text-path>"
- echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
+if [ $# != 3 ]; then
+ echo "Usage: $0 <audio-path> <text-path> <output-path>"
+ echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
exit 1;
fi
aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
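+# every train/dev/test list below is now written under $output_dir/data instead of a hard-coded ./data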
-train_dir=data/local/train
-dev_dir=data/local/dev
-test_dir=data/local/test
-tmp_dir=data/local/tmp
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
mkdir -p $train_dir
mkdir -p $dev_dir
@@ -53,12 +54,12 @@
sort -u $dir/transcripts.txt > $dir/text
done
-mkdir -p data/train data/dev data/test
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
- cp $train_dir/$f data/train/$f || exit 1;
- cp $dev_dir/$f data/dev/$f || exit 1;
- cp $test_dir/$f data/test/$f || exit 1;
+ cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+ cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+ cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done
echo "$0: AISHELL data preparation succeeded"
diff --git a/egs/aishell/paraformerbert/local/download_and_untar.sh b/egs/aishell/paraformerbert/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformerbert/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/paraformerbert/local/extract_embeds.sh b/egs/aishell/paraformerbert/local/extract_embeds.sh
index 9cf5940..ca0c878 100755
--- a/egs/aishell/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
stage=1
stop_stage=3
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
. utils/parse_options.sh || exit 1;
-nj=32
-
for data_set in train dev test;do
- scp=$raw_dataset_path/dump/fbank/${data_set}/text
- local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+ scp=$raw_dataset_path/data/${data_set}/text
+ local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
local_scp_dir=$local_scp_dir_raw/split$nj
local_records_dir=$local_scp_dir_raw/ark
@@ -58,6 +54,8 @@
cat ${local_records_dir}/embeds.${JOB}.shape || exit 1;
done > ${local_scp_dir_raw}/embeds.shape
fi
+
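+  # copy embeds.scp next to wav.scp/text so run.sh can hand train.py
+  # --data_file_names "wav.scp,text,embeds.scp"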
+ cp ${local_scp_dir_raw}/embeds.scp ${raw_dataset_path}/data/${data_set}/embeds.scp
done
echo "embeds is in: ${local_scp_dir_raw}"
diff --git a/egs/aishell/paraformerbert/run.sh b/egs/aishell/paraformerbert/run.sh
index 2487eac..5562f15 100755
--- a/egs/aishell/paraformerbert/run.sh
+++ b/egs/aishell/paraformerbert/run.sh
@@ -8,7 +8,7 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
@@ -16,29 +16,26 @@
feats_dir="../DATA" #feature output dictionary, for large data
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
skip_extract_embed=false
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -53,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -70,10 +67,17 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -83,46 +87,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,29 +102,27 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
if ! "${skip_extract_embed}"; then
echo "extract embeddings..."
local/extract_embeds.sh \
- --bert_model_root ${bert_model_root} \
--bert_model_name ${bert_model_name} \
- --raw_dataset_path ${feats_dir}
+ --raw_dataset_path ${feats_dir} \
+ --nj $nj
fi
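+    # re-runs can set --skip_extract_embed true to reuse embeddings from an earlier pass
+    # (flag assumed to be exposed via utils/parse_options.sh)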
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -172,31 +137,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.scp,embed,${type} \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --train_shape_file ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.shape \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.scp,embed,${type} \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
- --valid_shape_file ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.shape \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text,embeds.scp" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --allow_variable_data_keys true \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -207,8 +164,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -219,7 +176,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -240,6 +197,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -260,5 +218,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell/transformer/conf/train_asr_transformer.yaml b/egs/aishell/transformer/conf/train_asr_transformer.yaml
index ce987e7..b386565 100644
--- a/egs/aishell/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell/transformer/conf/train_asr_transformer.yaml
@@ -23,22 +23,28 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# minibatch related
-batch_type: length
-batch_bins: 32000
-num_workers: 8
-
# optimization related
accum_grad: 1
grad_clip: 5
-patience: 3
-max_epoch: 20
+patience: none
+max_epoch: 60
val_scheduler_criterion:
- valid
- acc
@@ -66,5 +72,17 @@
scheduler_conf:
warmup_steps: 25000
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
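+  # token batching: batch_size is a ~25k-token budget per minibatch, superseding the removed batch_bins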
+
log_interval: 50
normalize: None
diff --git a/egs/aishell/transformer/local/download_and_untar.sh b/egs/aishell/transformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/transformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/transformer/local/prepare_data.sh b/egs/aishell/transformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/transformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/transformer/path.sh b/egs/aishell/transformer/path.sh
index 7972642..b4064e1 100755
--- a/egs/aishell/transformer/path.sh
+++ b/egs/aishell/transformer/path.sh
@@ -3,3 +3,5 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
+
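+# pin intra-op threading to one thread so parallel train/decode jobs do not oversubscribe CPUs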
+export OMP_NUM_THREADS=1
diff --git a/egs/aishell/transformer/run.sh b/egs/aishell/transformer/run.sh
index f66a338..3db8a08 100755
--- a/egs/aishell/transformer/run.sh
+++ b/egs/aishell/transformer/run.sh
@@ -8,33 +8,31 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=8
+njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -48,8 +46,8 @@
valid_set=dev
test_sets="dev test"
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+asr_config=conf/train_asr_transformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
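+# on a first run, stage -1 fetches AISHELL-1 (data_aishell.tgz, ~15 GB) into ${raw_data}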
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,34 +93,32 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -161,26 +126,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
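+        # unified entry point: train.py dispatches on --task_name; fbank features are
+        # computed on the fly by the wav_frontend set in conf/train_asr_transformer.yaml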
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -191,8 +153,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -203,7 +165,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +186,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,4 +207,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/transformer/utils/cmvn_converter.py b/egs/aishell/transformer/utils/cmvn_converter.py
index cb978af..d405d12 100644
--- a/egs/aishell/transformer/utils/cmvn_converter.py
+++ b/egs/aishell/transformer/utils/cmvn_converter.py
@@ -9,16 +9,14 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--cmvn-json",
- "-c",
+ "--cmvn_json",
default=False,
required=True,
type=str,
help="cmvn json file",
)
parser.add_argument(
- "--am-mvn",
- "-a",
+ "--am_mvn",
default=False,
required=True,
type=str,
diff --git a/egs/aishell/transformer/utils/combine_cmvn_file.py b/egs/aishell/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
+import os
+
import numpy as np
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dim",
)
parser.add_argument(
- "--cmvn-dir",
- "-c",
+ "--cmvn_dir",
default=False,
required=True,
type=str,
@@ -25,15 +26,13 @@
parser.add_argument(
"--nj",
- "-n",
default=1,
required=True,
type=int,
- help="num of cmvn file",
+        help="number of cmvn files",
)
parser.add_argument(
- "--output-dir",
- "-o",
+ "--output_dir",
default=False,
required=True,
type=str,
@@ -46,14 +45,14 @@
parser = get_parser()
args = parser.parse_args()
- total_means = np.zeros(args.dims)
- total_vars = np.zeros(args.dims)
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
total_frames = 0
- cmvn_file = args.output_dir + "/cmvn.json"
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
- for i in range(1, args.nj+1):
- with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+ for i in range(1, args.nj + 1):
+            with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(i))) as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
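Because every shard stores raw sums of x and x squared plus a frame count, the reduction above is purely additive. A compact sketch under the same assumed JSON keys:

import numpy as np

def combine_shards(shards):
    # shards: list of dicts with mean_stats, var_stats, total_frames.
    mean_sum = sum(np.array(s["mean_stats"]) for s in shards)
    sq_sum = sum(np.array(s["var_stats"]) for s in shards)
    frames = sum(s["total_frames"] for s in shards)
    mean = mean_sum / frames
    var = sq_sum / frames - mean ** 2  # E[x^2] - E[x]^2
    return mean, var, frames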
diff --git a/egs/aishell/transformer/utils/compute_cmvn.py b/egs/aishell/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.py
+++ b/egs/aishell/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
import argparse
-import numpy as np
import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
def get_parser():
@@ -11,55 +13,83 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dimension",
)
parser.add_argument(
- "--ark-file",
- "-a",
+ "--wav_path",
default=False,
required=True,
type=str,
- help="fbank ark file",
+        help="directory containing the split wav.scp files",
)
parser.add_argument(
- "--ark-index",
- "-i",
+ "--idx",
default=1,
required=True,
type=int,
- help="ark index",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- default=False,
- required=True,
- type=str,
- help="output dir",
+        help="shard index of the split wav.scp",
)
return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
def main():
parser = get_parser()
args = parser.parse_args()
- ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
- cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
- mean_stats = np.zeros(args.dims)
- var_stats = np.zeros(args.dims)
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
total_frames = 0
- with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
- for key, mat in ark_reader:
- mean_stats += np.sum(mat, axis=0)
- var_stats += np.sum(np.square(mat), axis=0)
- total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
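A quick, hypothetical use of the compute_fbank helper introduced above (example.wav is a placeholder path, assumed to be a 16 kHz mono file):

feats = compute_fbank("example.wav", num_mel_bins=80)
# With 25 ms windows and a 10 ms shift at 16 kHz (speed 1.0), one second
# of audio yields roughly (1000 - 25) / 10 + 1 = 98 frames of 80-dim fbank.
print(feats.shape)  # (num_frames, 80)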
diff --git a/egs/aishell/transformer/utils/compute_cmvn.sh b/egs/aishell/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
. utils/parse_options.sh || exit 1;
fbankdir=$1
-logdir=$2
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
- python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
- || exit 1;
+ python utils/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+        --idx JOB || exit 1;
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/conformer/conf/train_asr_conformer.yaml b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
index 02fc5a8..8183378 100644
--- a/egs/aishell2/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
@@ -39,7 +50,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -72,10 +83,9 @@
- 40
num_time_mask: 2
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -83,4 +93,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
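The lfr_m/lfr_n pair in frontend_conf controls low-frame-rate stacking: lfr_m consecutive frames are concatenated and the result is sampled every lfr_n frames. With 1/1, as configured here, it is an identity op. A numpy sketch of the idea (the padding details of FunASR's wav_frontend may differ):

import numpy as np

def apply_lfr(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
    out = []
    T = feats.shape[0]
    for t in range(0, T, lfr_n):
        window = feats[t:t + lfr_m]
        if window.shape[0] < lfr_m:  # pad the tail by repeating the last frame
            pad = np.repeat(window[-1:], lfr_m - window.shape[0], axis=0)
            window = np.concatenate([window, pad], axis=0)
        out.append(window.reshape(-1))
    return np.stack(out)  # (ceil(T / lfr_n), lfr_m * feat_dim)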
diff --git a/egs/aishell2/conformer/local/prepare_data.sh b/egs/aishell2/conformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/conformer/local/prepare_data.sh
+++ b/egs/aishell2/conformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/conformer/run.sh b/egs/aishell2/conformer/run.sh
index f9ea69a..ae57431 100755
--- a/egs/aishell2/conformer/run.sh
+++ b/egs/aishell2/conformer/run.sh
@@ -9,27 +9,24 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
tr_dir=
@@ -51,13 +48,13 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
@@ -73,7 +70,7 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
done
@@ -83,51 +80,14 @@
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -140,23 +100,21 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -170,21 +128,24 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -195,8 +156,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -207,7 +168,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +189,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
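Each of these testing stages ends by tailing a text.cer report (see the first run.sh above). The metric is character error rate: Levenshtein distance between hypothesis and reference characters over reference length. A self-contained sketch, since the recipe's scoring tool itself is not part of this patch:

def cer(ref: str, hyp: str) -> float:
    r, h = list(ref), list(hyp)
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            cost = 0 if r[i - 1] == h[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,          # deletion
                          d[i][j - 1] + 1,          # insertion
                          d[i - 1][j - 1] + cost)   # substitution
    return d[len(r)][len(h)] / max(len(r), 1)

# cer("今天天气", "今天气") == 0.25: one deletion over four reference chars.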
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
index 4052774..767d8ba 100644
--- a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -2,47 +2,52 @@
# encoder related
encoder: data2vec_encoder
encoder_conf:
- extractor_mode: layer_norm
- encoder_layerdrop: 0.05
- dropout_input: 0.0
- dropout_features: 0.0
- feature_grad_mult: 1.0
- encoder_embed_dim: 768
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
- mask_prob: 0.65
- mask_length: 10
+ mask_prob: 0.65
+ mask_length: 10
- loss_beta: 0
- loss_scale: null
+ loss_beta: 0
+ loss_scale: null
- instance_norm_target_layer: true
- average_top_k_layers: 8
+ instance_norm_target_layer: true
+ average_top_k_layers: 8
- pos_conv_depth: 5
- conv_pos: 95
+ pos_conv_depth: 5
+ conv_pos: 95
- ema_decay: 0.999
- ema_end_decay: 0.9999
- ema_anneal_end_step: 30000
- ema_transformer_only: true
- ema_layers_only: true
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: true
- require_same_masks: true
- mask_dropout: 0
+ require_same_masks: true
+ mask_dropout: 0
-log_interval: 50
-normalize: None
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
-# minibatch related
-batch_type: length
-batch_bins: 64000
-num_workers: 16
+model: data2vec
# optimization related
accum_grad: 1
grad_clip: 5
patience: none
-max_epoch: 600
+max_epoch: 1800
val_scheduler_criterion:
- valid
- acc
@@ -67,8 +72,8 @@
# for dataset
dataset_conf:
batch_mode: clipping
- data_names: speech,none
- data_types: kaldi_ark,none
+ data_names: speech
+ data_types: sound
shuffle: true
shuffle_conf:
shuffle_size: 12800
@@ -76,4 +81,7 @@
batch_conf:
batch_type: token
batch_size: 64000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
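The ema_decay/ema_end_decay/ema_anneal_end_step trio parameterizes the data2vec teacher: an exponential moving average of the student whose decay anneals linearly from 0.999 to 0.9999 over the first 30000 steps, then holds. A sketch under that reading, assuming torch tensors (the real update lives in funasr's data2vec encoder):

def ema_decay_at(step: int, start: float = 0.999, end: float = 0.9999,
                 anneal_end: int = 30000) -> float:
    if step >= anneal_end:
        return end
    return start + (end - start) * step / anneal_end

def ema_update(teacher_params, student_params, step: int) -> None:
    d = ema_decay_at(step)
    for t, s in zip(teacher_params, student_params):
        t.mul_(d).add_(s, alpha=1.0 - d)  # t <- d * t + (1 - d) * s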
diff --git a/egs/aishell2/data2vec_pretrain/run.sh b/egs/aishell2/data2vec_pretrain/run.sh
index eceb183..9334a4b 100755
--- a/egs/aishell2/data2vec_pretrain/run.sh
+++ b/egs/aishell2/data2vec_pretrain/run.sh
@@ -7,24 +7,21 @@
gpu_num=8
count=1
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
stage=0
-stop_stage=4
+stop_stage=3
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
tr_dir=
@@ -45,68 +42,31 @@
valid_set=dev_ios
asr_config=conf/train_pretrain_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
# Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+ for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -114,23 +74,15 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
# Training Stage
world_size=$gpu_num # run on one machine
@@ -149,12 +101,17 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- data2vec_train.py \
+ train.py \
+ --task_name pretrain \
--gpu_id $gpu_id \
--use_preprocessor true \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--dataset_type $dataset_type \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
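On the task side, mask_prob: 0.65 and mask_length: 10 in the pretraining config describe wav2vec2/data2vec-style span masking over frames. A rough numpy sketch of the sampling (the real implementation's overlap and edge handling differ slightly):

import numpy as np

def span_mask(T: int, mask_prob: float = 0.65, mask_length: int = 10,
              rng=None) -> np.ndarray:
    rng = rng or np.random.default_rng()
    # Pick enough span starts that about mask_prob of the T frames get masked.
    num_spans = int(mask_prob * T / mask_length + rng.random())
    num_spans = min(num_spans, max(T - mask_length, 1))
    mask = np.zeros(T, dtype=bool)
    starts = rng.choice(max(T - mask_length, 1), size=num_spans, replace=False)
    for s in starts:
        mask[s:s + mask_length] = True
    return mask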
diff --git a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
index 450f71a..3ecf44e 100644
--- a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
+++ b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer
model_conf:
@@ -42,7 +53,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -82,10 +93,9 @@
l_order: 1
r_order: 1
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -93,4 +103,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
diff --git a/egs/aishell2/paraformer/local/prepare_data.sh b/egs/aishell2/paraformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/paraformer/local/prepare_data.sh
+++ b/egs/aishell2/paraformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/paraformer/run.sh b/egs/aishell2/paraformer/run.sh
index e1ea4fe..83e49d0 100755
--- a/egs/aishell2/paraformer/run.sh
+++ b/egs/aishell2/paraformer/run.sh
@@ -9,27 +9,24 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=1
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
tr_dir=
@@ -51,7 +48,7 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -73,61 +70,24 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
# Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+ for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,28 +95,26 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -170,33 +128,36 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
} &
- done
- wait
+ done
+ wait
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -207,7 +168,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +189,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
index 19f123e..8968d2d 100644
--- a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
+++ b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer_bert
model_conf:
@@ -36,7 +47,7 @@
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
predictor_weight: 1.0
- glat_context_p: 0.4
+ sampling_ratio: 0.4
embeds_id: 3
embed_dims: 768
embeds_loss_weight: 2.0
@@ -45,7 +56,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -78,24 +89,24 @@
- 40
num_time_mask: 2
-predictor: cif_predictor_sanm
+predictor: cif_predictor
predictor_conf:
idim: 320
threshold: 1.0
l_order: 1
r_order: 1
-log_interval: 50
-normalize: None
-
dataset_conf:
data_names: speech,text,embed
- data_types: kaldi_ark,text,kaldi_ark
+ data_types: sound,text,kaldi_ark
shuffle: True
shuffle_conf:
- shuffle_size: 10240
+ shuffle_size: 2048
sort_size: 500
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
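The predictor block (cif_predictor with threshold: 1.0) is continuous integrate-and-fire: per-frame weights are accumulated, and a token representation fires each time the running sum crosses the threshold, which is how Paraformer estimates token count without an external alignment. A simplified sketch (the real predictor also rescales weights to the target length during training and handles the unfired tail):

import numpy as np

def cif(alphas: np.ndarray, frames: np.ndarray, threshold: float = 1.0):
    # alphas: (T,) non-negative weights, assumed < threshold per frame;
    # frames: (T, D) encoder outputs.
    fired, acc = [], 0.0
    token = np.zeros(frames.shape[1])
    for a, h in zip(alphas, frames):
        if acc + a < threshold:
            acc += a
            token = token + a * h
        else:
            r = threshold - acc           # the part of this frame that fires
            fired.append(token + r * h)
            acc = a - r                   # leftover weight opens the next token
            token = acc * h
    return np.stack(fired) if fired else np.zeros((0, frames.shape[1]))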
diff --git a/egs/aishell2/paraformerbert/local/extract_embeds.sh b/egs/aishell2/paraformerbert/local/extract_embeds.sh
index 5f45ff3..d7dd4f2 100755
--- a/egs/aishell2/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell2/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
stage=1
stop_stage=3
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
. utils/parse_options.sh || exit 1;
-nj=100
-
-for data_set in train dev_ios test_ios;do
- scp=$raw_dataset_path/dump/fbank/${data_set}/text
- local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+for data_set in train dev_ios;do
+ scp=$raw_dataset_path/data/${data_set}/text
+ local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
local_scp_dir=$local_scp_dir_raw/split$nj
local_records_dir=$local_scp_dir_raw/ark
@@ -31,7 +27,7 @@
utils/split_scp.pl $scp ${split_scps}
- for num in {0..24};do
+ for num in {0..7};do
tmp=`expr $num \* 4`
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
@@ -41,20 +37,9 @@
{
beg=0
gpu=`expr $beg + $idx`
- echo $local_scp_dir_raw/log/log.${JOB}
- python tools/extract_embeds.py $local_scp_dir/text.$JOB.txt ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> $local_scp_dir_raw/log/log.${JOB}
+ echo ${local_scp_dir}/log.${JOB}
+ python utils/extract_embeds.py $local_scp_dir/data.$JOB.text ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> ${local_scp_dir}/log.${JOB}
} &
- done
- wait
- fi
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- for idx in {1..4}; do
- JOB=`expr $tmp + $idx`
- echo "upload jobid=$JOB"
- {
- hadoop fs -put -f ${local_records_dir}/embeds.${JOB}.ark ${odps_des_feature_dir}/embeds.${JOB}.ark
- } &
done
wait
fi
@@ -69,6 +54,8 @@
cat ${local_records_dir}/embeds.${JOB}.shape || exit 1;
done > ${local_scp_dir_raw}/embeds.shape
fi
+
+ cp ${local_scp_dir_raw}/embeds.scp ${raw_dataset_path}/data/${data_set}/embeds.scp
done
echo "embeds is in: ${local_scp_dir_raw}"
diff --git a/egs/aishell2/paraformerbert/local/prepare_data.sh b/egs/aishell2/paraformerbert/local/prepare_data.sh
index 801dbe5..77791f9 100755
--- a/egs/aishell2/paraformerbert/local/prepare_data.sh
+++ b/egs/aishell2/paraformerbert/local/prepare_data.sh
@@ -17,7 +17,6 @@
fi
corpus=$1
-#dict_dir=$2
tmp=$2
dir=$3
@@ -35,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/paraformerbert/run.sh b/egs/aishell2/paraformerbert/run.sh
index 239a7e3..a5e5ba9 100755
--- a/egs/aishell2/paraformerbert/run.sh
+++ b/egs/aishell2/paraformerbert/run.sh
@@ -8,32 +8,28 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=tools/run.pl
+njob=1
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
stage=0
stop_stage=5
skip_extract_embed=false
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
tr_dir=
@@ -55,7 +51,7 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -75,97 +71,59 @@
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# For training set
- local/prepare_data.sh ${tr_dir} data/local/train data/train || exit 1;
+ local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
- local/prepare_data.sh ${dev_tst_dir}/${x}/dev data/local/dev_${x,,} data/dev_${x,,} || exit 1;
- local/prepare_data.sh ${dev_tst_dir}/${x}/test data/local/test_${x,,} data/test_${x,,} || exit 1;
- done
+ for x in iOS; do
+ local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
+ local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
+ done
# Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
- mv data/${x}/text data/${x}/text.org
- paste <(cut -f 1 data/${x}/text.org) <(cut -f 2 data/${x}/text.org | tr '[:lower:]' '[:upper:]') \
- > data/${x}/text
- tools/text2token.py -n 1 -s 1 data/${x}/text > data/${x}/text.org
- mv data/${x}/text.org data/${x}/text
+ for x in train dev_ios test_ios; do
+ mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
+ | tr 'A-Z' 'a-z' | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- data/train exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- data/dev_${x} exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- data/test_${x} exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
- mkdir -p data/${lang}_token_list/char/
-
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p asr_stats_fbank_zh_char/${train_set}
- mkdir -p asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
if ! "${skip_extract_embed}"; then
echo "extract embeddings..."
local/extract_embeds.sh \
- --bert_model_root ${bert_model_root} \
--bert_model_name ${bert_model_name} \
- --raw_dataset_path ${feats_dir}
+ --raw_dataset_path ${feats_dir} \
+ --nj $nj
fi
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -180,22 +138,24 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
- --token_type $token_type \
+ --token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data_bert.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data_bert.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text,embeds.scp" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --allow_variable_data_keys true \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -206,8 +166,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -218,7 +178,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -239,6 +199,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -259,5 +220,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell2/transformer/conf/train_asr_transformer.yaml b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
index 3e2172d..1b76e2a 100644
--- a/egs/aishell2/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
@@ -23,6 +23,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
@@ -33,7 +44,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -66,10 +77,9 @@
- 40
num_time_mask: 2
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -77,4 +87,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
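The SpecAugment settings this config keeps (num_time_mask: 2, with what reads as a maximum width of 40 frames) zero out random spans of frames during training. A minimal time-masking sketch (FunASR's own augmentation module is not shown here):

import numpy as np

def time_mask(feats: np.ndarray, num_masks: int = 2, max_width: int = 40,
              rng=None) -> np.ndarray:
    rng = rng or np.random.default_rng()
    out = feats.copy()
    T = out.shape[0]
    for _ in range(num_masks):
        w = int(rng.integers(0, max_width + 1))
        t0 = int(rng.integers(0, max(T - w, 1)))
        out[t0:t0 + w] = 0.0  # zero a random span of frames
    return out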
diff --git a/egs/aishell2/transformer/local/prepare_data.sh b/egs/aishell2/transformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/transformer/local/prepare_data.sh
+++ b/egs/aishell2/transformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/transformer/run.sh b/egs/aishell2/transformer/run.sh
index 6f2dd4d..6e5c82a 100755
--- a/egs/aishell2/transformer/run.sh
+++ b/egs/aishell2/transformer/run.sh
@@ -9,27 +9,24 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
stage=0
-stop_stage=4
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
tr_dir=
@@ -51,13 +48,13 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
@@ -73,61 +70,24 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
# Normalize text to lower case
for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,28 +95,26 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -170,21 +128,24 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -195,8 +156,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5 Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -207,7 +168,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +189,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/transformer/utils/cmvn_converter.py b/egs/aishell2/transformer/utils/cmvn_converter.py
new file mode 100644
index 0000000..d405d12
--- /dev/null
+++ b/egs/aishell2/transformer/utils/cmvn_converter.py
@@ -0,0 +1,51 @@
+import argparse
+import json
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="cmvn converter",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--cmvn_json",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn json file",
+ )
+ parser.add_argument(
+ "--am_mvn",
+ default=False,
+ required=True,
+ type=str,
+ help="am mvn file",
+ )
+ return parser
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ with open(args.cmvn_json, "r") as fin:
+ cmvn_dict = json.load(fin)
+
+ mean_stats = np.array(cmvn_dict["mean_stats"])
+ var_stats = np.array(cmvn_dict["var_stats"])
+ total_frame = np.array(cmvn_dict["total_frames"])
+
+ mean = -1.0 * mean_stats / total_frame
+ var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean)
+ dims = mean.shape[0]
+ with open(args.am_mvn, 'w') as fout:
+ fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
+ mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
+ var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + var_str + '\n')
+ fout.write("</Nnet>" + '\n')
+
+if __name__ == '__main__':
+ main()
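
The <AddShift>/<Rescale> pair written above is mean-variance normalization in disguise: the stored "mean" is the negated mean (an additive shift) and the stored "var" is the inverse standard deviation. A minimal standalone check of that arithmetic on toy statistics (illustrative, not part of the patch):

    import numpy as np

    # Toy features standing in for fbank frames (100 frames, 4 dims).
    feats = np.random.randn(100, 4) * 3.0 + 7.0

    # Accumulate sufficient statistics the way compute_cmvn.py does.
    mean_stats = feats.sum(axis=0)
    var_stats = np.square(feats).sum(axis=0)
    total_frames = feats.shape[0]

    # Same arithmetic as cmvn_converter.py: "mean" is the negated mean
    # (an additive shift) and "var" is the inverse standard deviation.
    mean = -1.0 * mean_stats / total_frames
    var = 1.0 / np.sqrt(var_stats / total_frames - mean * mean)

    # Applying <AddShift> then <Rescale> standardizes each dimension.
    normed = (feats + mean) * var
    assert np.allclose(normed.mean(axis=0), 0.0, atol=1e-6)
    assert np.allclose(normed.std(axis=0), 1.0, atol=1e-6)
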
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell2/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
+import os
+
import numpy as np
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dim",
)
parser.add_argument(
- "--cmvn-dir",
- "-c",
+ "--cmvn_dir",
default=False,
required=True,
type=str,
@@ -25,15 +26,13 @@
parser.add_argument(
"--nj",
- "-n",
default=1,
required=True,
type=int,
- help="num of cmvn file",
+ help="num of cmvn files",
)
parser.add_argument(
- "--output-dir",
- "-o",
+ "--output_dir",
default=False,
required=True,
type=str,
@@ -46,14 +45,14 @@
parser = get_parser()
args = parser.parse_args()
- total_means = np.zeros(args.dims)
- total_vars = np.zeros(args.dims)
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
total_frames = 0
- cmvn_file = args.output_dir + "/cmvn.json"
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
- for i in range(1, args.nj+1):
- with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
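
Merging the per-job JSON files works because the accumulated quantities are sufficient statistics: per-dimension sums, sums of squares, and frame counts all add across jobs, whereas per-job means would not. A small demonstration on toy data (not part of the patch):

    import numpy as np

    # Three toy per-job chunks of 3-dim features.
    chunks = [np.random.randn(n, 3) for n in (50, 80, 20)]

    # Accumulate the same fields combine_cmvn_file.py reads from each
    # cmvn.<i>.json: mean_stats, var_stats, and the frame count.
    total_means = sum(c.sum(axis=0) for c in chunks)
    total_vars = sum(np.square(c).sum(axis=0) for c in chunks)
    total_frames = sum(c.shape[0] for c in chunks)

    # Identical to computing the statistics over all frames at once.
    full = np.concatenate(chunks)
    assert np.allclose(total_means, full.sum(axis=0))
    assert np.allclose(total_vars, np.square(full).sum(axis=0))
    assert total_frames == full.shape[0]
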
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.py b/egs/aishell2/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.py
+++ b/egs/aishell2/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
import argparse
-import numpy as np
import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
def get_parser():
@@ -11,55 +13,83 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dimension",
)
parser.add_argument(
- "--ark-file",
- "-a",
+ "--wav_path",
default=False,
required=True,
type=str,
- help="fbank ark file",
+ help="the path of wav scps",
)
parser.add_argument(
- "--ark-index",
- "-i",
+ "--idx",
default=1,
required=True,
type=int,
- help="ark index",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- default=False,
- required=True,
- type=str,
- help="output dir",
+ help="index",
)
return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
def main():
parser = get_parser()
args = parser.parse_args()
- ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
- cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
- mean_stats = np.zeros(args.dims)
- var_stats = np.zeros(args.dims)
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
total_frames = 0
- with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
- for key, mat in ark_reader:
- mean_stats += np.sum(mat, axis=0)
- var_stats += np.sum(np.square(mat), axis=0)
- total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
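
The (1 << 15) scaling in compute_fbank() exists because torchaudio.load() returns float32 samples in [-1.0, 1.0], while the Kaldi-compatible fbank expects int16-range values. A self-contained sketch with synthetic audio (illustrative only):

    import math

    import torch
    import torchaudio.compliance.kaldi as kaldi

    # One second of a 440 Hz tone at 16 kHz, shaped (channels, samples),
    # in float32 [-1, 1] as torchaudio.load() would return it.
    t = torch.arange(16000, dtype=torch.float32) / 16000.0
    waveform = torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)

    # Scale to int16 range before extracting 80-dim fbank features,
    # mirroring the (1 << 15) factor in compute_fbank() above.
    mat = kaldi.fbank(waveform * (1 << 15),
                      num_mel_bins=80,
                      frame_length=25,
                      frame_shift=10,
                      dither=0.0,
                      energy_floor=0.0,
                      window_type="hamming",
                      sample_frequency=16000)
    print(mat.shape)  # torch.Size([98, 80]): 10 ms shift, edges snipped
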
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.sh b/egs/aishell2/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell2/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
. utils/parse_options.sh || exit 1;
fbankdir=$1
-logdir=$2
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
- python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
- || exit 1;
+ python utils/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+        --idx JOB \
+        || exit 1;
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
index d03b5a8..9c3904f 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.py
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -14,7 +14,8 @@
frame_shift=10,
dither=0.0,
resample_rate=16000,
- speed=1.0):
+ speed=1.0,
+ window_type="hamming"):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
@@ -33,7 +34,7 @@
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
- window_type='hamming',
+ window_type=window_type,
sample_frequency=resample_rate)
return mat.numpy()
@@ -68,6 +69,13 @@
help="feature dims",
)
parser.add_argument(
+ "--max-lengths",
+ "-m",
+ default=1500,
+ type=int,
+ help="max frame numbers",
+ )
+ parser.add_argument(
"--sample-frequency",
"-s",
default=16000,
@@ -96,6 +104,13 @@
required=True,
type=str,
help="output dir",
+ )
+ parser.add_argument(
+ "--window-type",
+ default="hamming",
+ required=False,
+ type=str,
+ help="window type"
)
return parser
@@ -131,10 +146,13 @@
fbank = compute_fbank(wav_file,
num_mel_bins=args.dims,
resample_rate=args.sample_frequency,
- speed=float(speed)
+ speed=float(speed),
+ window_type=args.window_type
)
feats_dims = fbank.shape[1]
feats_lens = fbank.shape[0]
+ if feats_lens >= args.max_lengths:
+ continue
txt_lens = len(txt)
if speed == "1.0":
wav_id_sp = wav_id
diff --git a/egs/aishell2/transformer/utils/compute_fbank.sh b/egs/aishell2/transformer/utils/compute_fbank.sh
index 92a4fe6..8704b31 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.sh
+++ b/egs/aishell2/transformer/utils/compute_fbank.sh
@@ -9,6 +9,8 @@
feats_dim=80
sample_frequency=16000
speed_perturb="1.0"
+window_type="hamming"
+max_lengths=1500
echo "$0 $@"
@@ -29,7 +31,8 @@
$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
- -d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
+ -d $feats_dim -s $sample_frequency -m ${max_lengths} -p ${speed_perturb} -a JOB -o ${fbankdir} \
+ --window-type ${window_type} \
|| exit 1;
for n in $(seq $nj); do
diff --git a/egs/aishell2/transformer/utils/compute_wer.py b/egs/aishell2/transformer/utils/compute_wer.py
index 349a3f6..26a9f49 100755
--- a/egs/aishell2/transformer/utils/compute_wer.py
+++ b/egs/aishell2/transformer/utils/compute_wer.py
@@ -45,8 +45,8 @@
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
- cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
- cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
diff --git a/egs/aishell2/transformer/utils/download_model.py b/egs/aishell2/transformer/utils/download_model.py
new file mode 100755
index 0000000..70ea179
--- /dev/null
+++ b/egs/aishell2/transformer/utils/download_model.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import argparse
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="download model configs",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("--model_name",
+ type=str,
+ default="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
+ help="model name in ModelScope")
+ args = parser.parse_args()
+
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.model_name)
diff --git a/egs/aishell2/transformer/utils/fix_data.sh b/egs/aishell2/transformer/utils/fix_data.sh
index 32cdde5..b1a2bb8 100755
--- a/egs/aishell2/transformer/utils/fix_data.sh
+++ b/egs/aishell2/transformer/utils/fix_data.sh
@@ -28,8 +28,8 @@
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak
diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh
index 2c92d7f..84eea36 100755
--- a/egs/aishell2/transformer/utils/fix_data_feat.sh
+++ b/egs/aishell2/transformer/utils/fix_data_feat.sh
@@ -40,10 +40,10 @@
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
index 421d7df..aa48b2d 100644
--- a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
+++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -43,7 +43,6 @@
pooling_type: statistic
num_nodes_resnet1: 256
num_nodes_last_layer: 256
- batchnorm_momentum: 0.5
# decoder related
decoder: sa_decoder
diff --git a/egs/librispeech/conformer/conf/decode_asr_transformer.yaml b/egs/librispeech/conformer/conf/decode_asr_transformer.yaml
deleted file mode 100644
index a147fa7..0000000
--- a/egs/librispeech/conformer/conf/decode_asr_transformer.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-beam_size: 10
-penalty: 0.0
-maxlenratio: 0.0
-minlenratio: 0.0
-ctc_weight: 0.5
-lm_weight: 0.7
diff --git a/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml b/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml
new file mode 100644
index 0000000..8f7c75d
--- /dev/null
+++ b/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml
@@ -0,0 +1,6 @@
+beam_size: 5
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.3
+lm_weight: 0.0
diff --git a/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam60.yaml b/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam60.yaml
new file mode 100644
index 0000000..0ebb9af
--- /dev/null
+++ b/egs/librispeech/conformer/conf/decode_asr_transformer_ctc0.3_beam60.yaml
@@ -0,0 +1,6 @@
+beam_size: 60
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.3
+lm_weight: 0.0
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer.yaml b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
index 68b127f..2bd3db4 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer.yaml
+++ b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
@@ -27,13 +27,25 @@
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
accum_grad: 2
-max_epoch: 50
+max_epoch: 150
patience: none
init: none
best_model_criterion:
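
In the frontend block above, lfr_m/lfr_n control low-frame-rate stacking (stack lfr_m consecutive frames, advance lfr_n per step), so the 1/1 setting leaves the fbank frames untouched. A rough sketch of the stacking idea (hypothetical helper; the real wav_frontend's padding and edge handling may differ):

    import numpy as np

    def apply_lfr(frames, lfr_m=1, lfr_n=1):
        # Stack lfr_m consecutive frames, advancing lfr_n per output step.
        num_frames, dim = frames.shape
        out = []
        for t in range(0, num_frames - lfr_m + 1, lfr_n):
            out.append(frames[t:t + lfr_m].reshape(-1))
        return np.stack(out)

    feats = np.random.randn(100, 80)
    assert apply_lfr(feats, 1, 1).shape == (100, 80)   # the 1/1 config: identity
    assert apply_lfr(feats, 4, 3).shape == (33, 320)   # 4x stacking, 3x decimation
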
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech/conformer/local/data_prep.sh
similarity index 100%
rename from egs/librispeech/conformer/local/data_prep_librispeech.sh
rename to egs/librispeech/conformer/local/data_prep.sh
diff --git a/egs/librispeech/conformer/local/download_and_untar.sh b/egs/librispeech/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# Sizes of the archive files in bytes, from some older releases.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new contains the archive file sizes of the final release. Some of these
+# are for parts we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
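
The size comparison above is the script's only integrity check: an existing tarball is kept only if its byte count matches a known release size; otherwise it is removed and re-downloaded. The same logic in Python (sizes abbreviated; hypothetical helper):

    import os

    KNOWN_SIZES = {371012589, 347390293, 337926286, 314305928}  # subset of the lists above

    def archive_ok(path):
        # Trust an existing archive only if its size matches a known release.
        return os.path.exists(path) and os.path.getsize(path) in KNOWN_SIZES
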
diff --git a/egs/librispeech/conformer/local/spm_encode.py b/egs/librispeech/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True,
+ help="sentencepiece model to use for encoding")
+ parser.add_argument("--inputs", nargs="+", default=['-'],
+ help="input files to filter/encode")
+ parser.add_argument("--outputs", nargs="+", default=['-'],
+ help="path to save encoded outputs")
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument("--min-len", type=int, metavar="N",
+ help="filter sentence pairs with fewer than N tokens")
+ parser.add_argument("--max-len", type=int, metavar="N",
+ help="filter sentence pairs with more than N tokens")
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(args.outputs), \
+ "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+ def encode(l):
+ return sp.EncodeAsPieces(l)
+ elif args.output_format == "id":
+ def encode(l):
+ return list(map(str, sp.EncodeAsIds(l)))
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+ def valid(line):
+ return (
+ (args.min_len is None or len(line) >= args.min_len) and
+ (args.max_len is None or len(line) <= args.max_len)
+ )
+ else:
+ def valid(lines):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-" else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-" else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/conformer/local/spm_train.py b/egs/librispeech/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
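
Together the two local helpers reproduce what stage 2 of run.sh does: train a SentencePiece model on input.txt, then encode the same text to pieces to build the token list. A condensed sketch with the Python API (paths and model prefix are illustrative):

    import sentencepiece as spm

    # Train a unigram model, as local/spm_train.py does via its CLI args.
    spm.SentencePieceTrainer.Train(
        "--input=input.txt --vocab_size=5000 "
        "--model_type=unigram --model_prefix=bpemodel"
    )

    # Encode text to pieces, as local/spm_encode.py does line by line.
    sp = spm.SentencePieceProcessor()
    sp.Load("bpemodel.model")
    print(" ".join(sp.EncodeAsPieces("HELLO WORLD")))
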
diff --git a/egs/librispeech/conformer/run.sh b/egs/librispeech/conformer/run.sh
index 93d1b46..b44dad3 100755
--- a/egs/librispeech/conformer/run.sh
+++ b/egs/librispeech/conformer/run.sh
@@ -16,30 +16,27 @@
feats_dir="../DATA" #feature output dictionary
exp_dir="."
lang=en
-dumpdir=dump/fbank
-feats_type=fbank
token_type=bpe
-dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_librispeech=
+raw_data=
+data_url=www.openslr.org/resources/12
# bpe model
nbpe=5000
bpemode=unigram
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -54,12 +51,11 @@
test_sets="test_clean test_other dev_clean dev_other"
asr_config=conf/train_asr_conformer.yaml
-#asr_config=conf/train_asr_conformer_uttnorm.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
-inference_config=conf/decode_asr_transformer.yaml
-#inference_config=conf/decode_asr_transformer_beam60_ctc0.3.yaml
-inference_asr_model=valid.acc.ave_10best.pth
+inference_config=conf/decode_asr_transformer_ctc0.3_beam5.yaml
+#inference_config=conf/decode_asr_transformer_ctc0.3_beam60.yaml
+inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
@@ -73,101 +69,63 @@
_ngpu=0
fi
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
- echo "stage 0: Data preparation"
- # Data preparation
- for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
- local/data_prep_librispeech.sh ${data_librispeech}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+ local/download_and_untar.sh ${raw_data} ${data_url} ${part}
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/$train_set; mkdir -p ${feat_train_dir}
-feat_dev_clean_dir=${feats_dir}/${dumpdir}/dev_clean; mkdir -p ${feat_dev_clean_dir}
-feat_dev_other_dir=${feats_dir}/${dumpdir}/dev_other; mkdir -p ${feat_dev_other_dir}
-feat_test_clean_dir=${feats_dir}/${dumpdir}/test_clean; mkdir -p ${feat_test_clean_dir}
-feat_test_other_dir=${feats_dir}/${dumpdir}/test_other; mkdir -p ${feat_test_other_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/$valid_set; mkdir -p ${feat_dev_dir}
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- for x in dev_clean dev_other test_clean test_other; do
- utils/compute_fbank.sh --cmd "$train_cmd" --nj 1 --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/${x} ${exp_dir}/exp/make_fbank/${x} ${fbankdir}/${x}
- utils/fix_data_feat.sh ${fbankdir}/${x}
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ for x in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do
+ local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
done
-
- mkdir ${feats_dir}/data/$train_set
+    mkdir -p $feats_dir/data/$valid_set
+ dev_sets="dev_clean dev_other"
+ for file in wav.scp text; do
+ ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+ done
+    mkdir -p $feats_dir/data/$train_set
train_sets="train_clean_100 train_clean_360 train_other_500"
for file in wav.scp text; do
( for f in $train_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$train_set/$file || exit 1;
done
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/$train_set ${exp_dir}/exp/make_fbank/$train_set ${fbankdir}/$train_set
- utils/fix_data_feat.sh ${fbankdir}/$train_set
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/$train_set ${exp_dir}/exp/make_fbank/$train_set
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/$train_set ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/$train_set ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/dev_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_clean ${feat_dev_clean_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1\
- ${fbankdir}/dev_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_other ${feat_dev_other_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/test_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_clean ${feat_test_clean_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/test_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_other ${feat_test_other_dir}
-
- cp ${fbankdir}/$train_set/text ${fbankdir}/$train_set/speech_shape ${fbankdir}/$train_set/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev_clean/text ${fbankdir}/dev_clean/speech_shape ${fbankdir}/dev_clean/text_shape ${feat_dev_clean_dir}
- cp ${fbankdir}/dev_other/text ${fbankdir}/dev_other/speech_shape ${fbankdir}/dev_other/text_shape ${feat_dev_other_dir}
- cp ${fbankdir}/test_clean/text ${fbankdir}/test_clean/speech_shape ${fbankdir}/test_clean/text_shape ${feat_test_clean_dir}
- cp ${fbankdir}/test_other/text ${fbankdir}/test_other/speech_shape ${fbankdir}/test_other/text_shape ${feat_test_other_dir}
-
- dev_sets="dev_clean dev_other"
- for file in feats.scp text speech_shape text_shape; do
- ( for f in $dev_sets; do cat $feats_dir/${dumpdir}/$f/$file; done ) | sort -k1 > $feat_dev_dir/$file || exit 1;
- done
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/${train_set} ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/${valid_set} ${feat_dev_dir}
fi
-dict=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
-echo "dictionary: ${dict}"
+echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
### Task dependent. You have to check non-linguistic symbols used in the corpus.
echo "stage 2: Dictionary and Json Data Preparation"
mkdir -p ${feats_dir}/data/lang_char/
- echo "<blank>" > ${dict}
- echo "<s>" >> ${dict}
- echo "</s>" >> ${dict}
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
- spm_train --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
- spm_encode --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${dict}
- echo "<unk>" >> ${dict}
- wc -l ${dict}
-
- vocab_size=$(cat ${dict} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$train_set
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$train_set
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
+ local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+ local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
fi
-
-# Training Stage
+# LM Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
@@ -181,20 +139,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--split_with_space false \
--bpemodel ${bpemodel}.model \
--token_type $token_type \
- --dataset_type $dataset_type \
- --token_list $dict \
- --train_data_file $feats_dir/$dumpdir/${train_set}/ark_txt.scp \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/ark_txt.scp \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
@@ -208,8 +168,8 @@
fi
# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
for dset in ${test_sets}; do
asr_exp=${exp_dir}/exp/${model_dir}
inference_tag="$(basename "${inference_config}" .yaml)"
@@ -220,7 +180,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -241,6 +201,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam1.yaml b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam1.yaml
new file mode 100644
index 0000000..edc6bab
--- /dev/null
+++ b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam1.yaml
@@ -0,0 +1,6 @@
+beam_size: 1
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.3
+lm_weight: 0.0
diff --git a/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam20.yaml b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam20.yaml
new file mode 100644
index 0000000..b2b0427
--- /dev/null
+++ b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam20.yaml
@@ -0,0 +1,6 @@
+beam_size: 20
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.3
+lm_weight: 0.0
diff --git a/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml
new file mode 100644
index 0000000..8f7c75d
--- /dev/null
+++ b/egs/librispeech_100h/conformer/conf/decode_asr_transformer_ctc0.3_beam5.yaml
@@ -0,0 +1,6 @@
+beam_size: 5
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.3
+lm_weight: 0.0
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
similarity index 77%
rename from egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
rename to egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
index 16b7cc0..bd92bb0 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
+++ b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
@@ -1,8 +1,8 @@
encoder: conformer
encoder_conf:
- output_size: 512
- attention_heads: 8
- linear_units: 2048
+ output_size: 256
+ attention_heads: 4
+ linear_units: 1024
num_blocks: 12
dropout_rate: 0.1
positional_dropout_rate: 0.1
@@ -19,7 +19,7 @@
decoder: transformer
decoder_conf:
- attention_heads: 8
+ attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
@@ -27,13 +27,25 @@
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
-accum_grad: 2
-max_epoch: 50
+accum_grad: 1
+max_epoch: 210
patience: none
init: none
best_model_criterion:
@@ -44,11 +56,11 @@
optim: adam
optim_conf:
- lr: 0.0025
+ lr: 0.002
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
- warmup_steps: 40000
+ warmup_steps: 15000
specaug: specaug
specaug_conf:
@@ -64,7 +76,7 @@
time_mask_width_ratio_range:
- 0.
- 0.05
- num_time_mask: 10
+ num_time_mask: 5
dataset_conf:
shuffle: True
@@ -77,4 +89,4 @@
num_workers: 8
log_interval: 50
-normalize: utterance_mvn
\ No newline at end of file
+normalize: None
\ No newline at end of file
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech_100h/conformer/local/data_prep.sh
similarity index 100%
copy from egs/librispeech/conformer/local/data_prep_librispeech.sh
copy to egs/librispeech_100h/conformer/local/data_prep.sh
diff --git a/egs/librispeech_100h/conformer/local/download_and_untar.sh b/egs/librispeech_100h/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# Sizes of the archive files in bytes, from some older releases.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new contains the archive file sizes of the final release. Some of these
+# are for parts we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
diff --git a/egs/librispeech_100h/conformer/local/spm_encode.py b/egs/librispeech_100h/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True,
+ help="sentencepiece model to use for encoding")
+ parser.add_argument("--inputs", nargs="+", default=['-'],
+ help="input files to filter/encode")
+ parser.add_argument("--outputs", nargs="+", default=['-'],
+ help="path to save encoded outputs")
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument("--min-len", type=int, metavar="N",
+ help="filter sentence pairs with fewer than N tokens")
+ parser.add_argument("--max-len", type=int, metavar="N",
+ help="filter sentence pairs with more than N tokens")
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(args.outputs), \
+ "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+ def encode(l):
+ return sp.EncodeAsPieces(l)
+ elif args.output_format == "id":
+ def encode(l):
+ return list(map(str, sp.EncodeAsIds(l)))
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+ def valid(line):
+ return (
+ (args.min_len is None or len(line) >= args.min_len) and
+ (args.max_len is None or len(line) <= args.max_len)
+ )
+ else:
+ def valid(lines):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-" else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-" else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech_100h/conformer/local/spm_train.py b/egs/librispeech_100h/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/egs/mars/sd/path.sh b/egs/librispeech_100h/conformer/path.sh
similarity index 100%
rename from egs/mars/sd/path.sh
rename to egs/librispeech_100h/conformer/path.sh
diff --git a/egs/librispeech_100h/conformer/run.sh b/egs/librispeech_100h/conformer/run.sh
new file mode 100755
index 0000000..f0db69c
--- /dev/null
+++ b/egs/librispeech_100h/conformer/run.sh
@@ -0,0 +1,219 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=en
+token_type=bpe
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=
+data_url=www.openslr.org/resources/12
+
+# bpe model
+nbpe=5000
+bpemode=unigram
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train_clean_100
+valid_set=dev
+test_sets="test_clean test_other dev_clean dev_other"
+
+asr_config=conf/train_asr_conformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+#inference_config=conf/decode_asr_transformer_ctc0.3_beam1.yaml
+inference_config=conf/decode_asr_transformer_ctc0.3_beam5.yaml
+#inference_config=conf/decode_asr_transformer_ctc0.3_beam20.yaml
+inference_asr_model=valid.acc.ave_10best.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100; do
+ local/download_and_untar.sh ${raw_data} ${data_url} ${part}
+ done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ for x in dev-clean dev-other test-clean test-other train-clean-100; do
+ local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+ done
+    mkdir -p $feats_dir/data/$valid_set
+ dev_sets="dev_clean dev_other"
+ for file in wav.scp text; do
+ ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ ### Task dependent. You have to check non-linguistic symbols used in the corpus.
+ echo "stage 2: Dictionary and Json Data Preparation"
+ mkdir -p ${feats_dir}/data/lang_char/
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
+ local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+ local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --bpemodel ${bpemodel}.model \
+ --token_type $token_type \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/data/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/compute_wer.py ${_data}/text ${_dir}/text ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
\ No newline at end of file
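
During inference, the key file is divided with utils/split_scp.pl so that each of the _nj jobs decodes a disjoint subset of utterances. A rough Python equivalent of that contiguous split (hypothetical helper; the Perl tool's exact balancing may differ):

    def split_scp(lines, num_jobs):
        # Divide scp lines into num_jobs roughly equal contiguous chunks.
        base, rem = divmod(len(lines), num_jobs)
        chunks, start = [], 0
        for i in range(num_jobs):
            size = base + (1 if i < rem else 0)
            chunks.append(lines[start:start + size])
            start += size
        return chunks

    assert [len(c) for c in split_scp(list(range(10)), 3)] == [4, 3, 3]
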
diff --git a/egs/librispeech_100h/conformer/utils b/egs/librispeech_100h/conformer/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/librispeech_100h/conformer/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
deleted file mode 100644
index 459a741..0000000
--- a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model: sond
-model_conf:
- lsm_weight: 0.0
- length_normalized_loss: true
- max_spk_num: 16
-
-# speech encoder
-encoder: ecapa_tdnn
-encoder_conf:
- # pass by model, equal to feature dim
- # input_size: 80
- pool_size: 20
- stride: 1
-speaker_encoder: conv
-speaker_encoder_conf:
- input_units: 256
- num_layers: 3
- num_units: 256
- kernel_size: 1
- dropout_rate: 0.0
- position_encoder: null
- out_units: 256
- out_norm: false
- auxiliary_states: false
- tf2torch_tensor_name_prefix_torch: speaker_encoder
- tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
-ci_scorer: dot
-ci_scorer_conf: {}
-cd_scorer: san
-cd_scorer_conf:
- input_size: 512
- output_size: 512
- out_units: 1
- attention_heads: 4
- linear_units: 1024
- num_blocks: 4
- dropout_rate: 0.0
- positional_dropout_rate: 0.0
- attention_dropout_rate: 0.0
- # use string "null" to remove input layer
- input_layer: "null"
- pos_enc_class: null
- normalize_before: true
- tf2torch_tensor_name_prefix_torch: cd_scorer
- tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
-# post net
-decoder: fsmn
-decoder_conf:
- in_units: 32
- out_units: 2517
- filter_size: 31
- fsmn_num_layers: 6
- dnn_num_layers: 1
- num_memory_units: 512
- ffn_inner_dim: 512
- dropout_rate: 0.0
- tf2torch_tensor_name_prefix_torch: decoder
- tf2torch_tensor_name_prefix_tf: EAND/post_net
-frontend: wav_frontend
-frontend_conf:
- fs: 16000
- window: povey
- n_mels: 80
- frame_length: 25
- frame_shift: 10
- filter_length_min: -1
- filter_length_max: -1
- lfr_m: 1
- lfr_n: 1
- dither: 0.0
- snip_edges: false
-
-# minibatch related
-batch_type: length
-# 16s * 16k * 16 samples
-batch_bins: 4096000
-num_workers: 8
-
-# optimization related
-accum_grad: 1
-grad_clip: 5
-max_epoch: 50
-val_scheduler_criterion:
- - valid
- - acc
-best_model_criterion:
-- - valid
- - der
- - min
-- - valid
- - forward_steps
- - max
-keep_nbest_models: 10
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 10000
-
-# without spec aug
-specaug: null
-specaug_conf:
- apply_time_warp: true
- time_warp_window: 5
- time_warp_mode: bicubic
- apply_freq_mask: true
- freq_mask_width_range:
- - 0
- - 30
- num_freq_mask: 2
- apply_time_mask: true
- time_mask_width_range:
- - 0
- - 40
- num_time_mask: 2
-
-log_interval: 50
-# without normalize
-normalize: null
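The `batch_bins` value in the config above follows directly from its comment: with `batch_type: length` the budget is counted in samples, and 16 seconds at 16 kHz for 16 utterances gives exactly the configured number. A one-line check:

```python
# 16 s per utterance * 16000 Hz * 16 utterances = 4,096,000 length bins
assert 16 * 16000 * 16 == 4096000
```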
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
deleted file mode 100755
index 4516e9f..0000000
--- a/egs/mars/sd/local_run.sh
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env bash
-
-. ./path.sh || exit 1;
-
-# machines configuration
-CUDA_VISIBLE_DEVICES="6,7"
-gpu_num=2
-count=1
-gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
-# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=utils/run.pl
-infer_cmd=utils/run.pl
-
-# general configuration
-feats_dir="." # feature output directory
-exp_dir="."
-lang=zh
-dumpdir=dump/raw
-feats_type=raw
-token_type=char
-scp=wav.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
-
-# feature configuration
-feats_dim=
-sample_frequency=16000
-nj=32
-speed_perturb=
-
-# exp tag
-tag="exp1"
-
-. utils/parse_options.sh || exit 1;
-
-# Set bash to 'debug' mode; it will exit on:
-# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -x 'print commands'
-set -e
-set -u
-set -o pipefail
-
-train_set=train
-valid_set=dev
-test_sets="dev test"
-
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
-
-inference_config=conf/decode_asr_transformer.yaml
-inference_asr_model=valid.acc.ave_10best.pb
-
-# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
-ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
-
-if ${gpu_inference}; then
- inference_nj=$[${ngpu}*${njob}]
- _ngpu=1
-else
- inference_nj=$njob
- _ngpu=0
-fi
-
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
-
-# Training Stage
-world_size=$gpu_num # run on one machine
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
- mkdir -p ${exp_dir}/exp/${model_dir}
- mkdir -p ${exp_dir}/exp/${model_dir}/log
- INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $gpu_num; ++i)); do
- {
- rank=$i
- local_rank=$i
- gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --token_type char \
- --token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
- --resume true \
- --output_dir ${exp_dir}/exp/${model_dir} \
- --config $asr_config \
- --input_size $feats_dim \
- --ngpu $gpu_num \
- --num_worker_count $count \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $world_size \
- --dist_rank $rank \
- --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
- } &
- done
- wait
-fi
-
-# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
- for dset in ${test_sets}; do
- asr_exp=${exp_dir}/exp/${model_dir}
- inference_tag="$(basename "${inference_config}" .yaml)"
- _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
- _logdir="${_dir}/logdir"
- if [ -d ${_dir} ]; then
-            echo "${_dir} already exists. If you want to decode again, please delete this dir first."
- exit 0
- fi
- mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
- key_file=${_data}/${scp}
- num_scp_file="$(<${key_file} wc -l)"
- _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
- split_scps=
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
-        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob} \
- --gpuid_list ${gpuid_list} \
- --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --asr_train_config "${asr_exp}"/config.yaml \
- --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode asr \
- ${_opts}
-
- for f in token token_int score text; do
- if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | sort -k1 >"${_dir}/${f}"
- fi
- done
- python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python utils/proce_text.py ${_data}/text ${_data}/text.proc
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
- done
-fi
-
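Both the run.sh above and this deleted local_run.sh rendezvous the per-GPU training processes through a file:// init method pointing at the `ddp_init` file. A minimal sketch of how such a rendezvous is typically consumed with torch.distributed (an assumption about the training entry point, not FunASR's actual code):

```python
import torch
import torch.distributed as dist

def init_ddp(init_method: str, world_size: int, rank: int, local_rank: int) -> None:
    # File-based rendezvous: every rank opens the same ddp_init file,
    # e.g. init_method="file:///path/to/exp/ddp_init".
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
```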
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
deleted file mode 100644
index b207f2d..0000000
--- a/egs/mars/sd/scripts/calculate_shapes.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--input_scp", type=str, required=True)
- parser.add_argument("--out_path")
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out_path)):
- os.makedirs(os.path.dirname(args.out_path))
-
- task_list = load_scp_as_list(args.input_scp)
- return task_list, None, args
-
- def post(self, result_list, args):
- fd = open(args.out_path, "wt", encoding="utf-8")
- for results in result_list:
- for uttid, shape in results:
- fd.write("{} {}\n".format(uttid, ",".join(shape)))
- fd.close()
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- rst = []
- for uttid, file_path in task_list:
- data = kaldiio.load_mat(file_path)
- shape = [str(x) for x in data.shape]
- rst.append((uttid, shape))
- return rst
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
deleted file mode 100644
index ec1c765..0000000
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--rttm_list", type=str, required=True)
- parser.add_argument("--wav_scp_list", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--n_spk", type=int, default=8)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- parser.add_argument("--max_overlap", default=0, type=int)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- args = parser.parse_args()
-
- rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
- meeting2rttm = OrderedDict()
- for rttm_path in rttm_list:
- meeting2rttm.update(self.load_rttm(rttm_path))
-
- wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
- meeting_scp = OrderedDict()
- for scp_path in wav_scp_list:
- meeting_scp.update(load_scp_as_dict(scp_path))
-
- if len(meeting_scp) != len(meeting2rttm):
- logging.warning("Number of wav and rttm mismatch {} != {}".format(
- len(meeting_scp), len(meeting2rttm)))
- common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
- logging.warning("Keep {} records.".format(len(common_keys)))
- new_meeting_scp = OrderedDict()
- rm_keys = []
- for key in meeting_scp:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting_scp[key] = meeting_scp[key]
- logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
-
- new_meeting2rttm = OrderedDict()
- rm_keys = []
- for key in meeting2rttm:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting2rttm[key] = meeting2rttm[key]
- logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
- meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
- return task_list, None, args
-
- @staticmethod
- def load_rttm(rttm_path):
- meeting2rttm = OrderedDict()
- for one_line in open(rttm_path, "rt", encoding="utf-8"):
- mid = one_line.strip().split(" ")[1]
- if mid not in meeting2rttm:
- meeting2rttm[mid] = []
- meeting2rttm[mid].append(one_line.strip())
-
- return meeting2rttm
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
- sr=None, frame_shift=0.01):
- frame_shift = int(frame_shift * sr)
- num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
- multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
- for _, st, dur, spk in spk_turns:
- idx = spk_list.index(spk)
-
- st, dur = int(st * sr), int(dur * sr)
- frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
- frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
- multi_label[idx, frame_st:frame_ed] = 1
-
- if remove_sil:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- multi_label = multi_label[:, idx]
-
- if max_overlap > 0:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count <= max_overlap)[0]
- multi_label = multi_label[:, idx]
-
- label = multi_label.T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
- sr=16000, frame_shift=0.01):
- wav, sr = soundfile.read(wav_path)
- wav_len = len(wav)
- spk_turns = []
- spk_list = []
- for one_line in rttms:
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
- return labels, spk_list
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
- "wt", encoding="utf-8")
- out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
- for mid, wav_path, rttms in task_list:
- meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
- args.sr, args.frame_shift)
- label_writer(mid, meeting_labels)
- spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
-
- spk_list_writer.close()
- label_writer.close()
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
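In `calc_labels` above, adding half a frame shift before the integer division rounds a sample index to the nearest frame boundary instead of truncating. A quick worked check with the script's defaults (10 ms frames at 16 kHz, i.e. a 160-sample shift):

```python
frame_shift = 160              # int(0.01 * 16000), as in the script above
st = 250                       # sample index of a speaker-turn start
truncated = st // frame_shift                        # 1 (floor)
nearest = int((st + frame_shift / 2) / frame_shift)  # 2 (round to nearest)
assert (truncated, nearest) == (1, 2)
```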
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
deleted file mode 100644
index cd1ec7b..0000000
--- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import numpy as np
-import os
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import soundfile as sf
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- meeting2rttms = {}
- for one_line in open(args.rttm, "rt"):
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if mid not in meeting2rttms:
- meeting2rttms[mid] = []
- meeting2rttms[mid].append(one_line)
-
- task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in rttms:
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- else:
- spk_name = "{}_{}".format(mid, spk_name)
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttms in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
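`get_nonoverlap_turns` above keeps only the frames where exactly one speaker is active and merges them into single-speaker turns. A small worked example of the expected behaviour, traced against the function as written:

```python
import numpy as np

# Two speakers over 6 frames: A alone (0-2), overlap (2-4), B alone (4-6).
multi_label = np.array([
    [1, 1, 1, 1, 0, 0],   # speaker "A"
    [0, 0, 1, 1, 1, 1],   # speaker "B"
])
# sum(axis=0) == 1 is True only at frames 0-1 and 4-5, so the overlap
# region is dropped and the result is [[0, 2, "A"], [4, 6, "B"]].
```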
diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
deleted file mode 100644
index e579f51..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
- multi_label = np.zeros([n_spk, length], dtype=int)
- for _, st, dur, spk in spk_turns:
- st, dur = int(st * sr), int(dur * sr)
- idx = spk_list.index(spk)
- multi_label[idx, st:st+dur] = 1
- if not remove_sil:
- return multi_label.T
-
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- label = multi_label[:, idx].T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
- wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
- spk_turns = []
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
- return labels
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, rttm_path in task_list:
- meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
- save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
- np.save(save_path, meeting_labels.astype(bool))
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
deleted file mode 100644
index 11bc395..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-from tqdm import tqdm
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--chunk_dur", type=float, default=16)
- parser.add_argument("--shift_dur", type=float, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- return wav_scp, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
- for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
- if not os.path.exists(os.path.join(args.out_dir, mid)):
- os.makedirs(os.path.join(args.out_dir, mid))
-
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- n_chunk = (len(wav) - chunk_len) // shift_len + 1
- if (len(wav) - chunk_len) % shift_len > 0:
- n_chunk += 1
- for i in range(n_chunk):
- seg = wav[i*shift_len: i*shift_len + chunk_len]
- st = int(float(i*shift_len)/args.sr * 100)
- dur = int(float(len(seg))/args.sr * 100)
- file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
- save_path = os.path.join(args.out_dir, mid, file_name)
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
deleted file mode 100644
index 011bd7c..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("--rttm_scp", type=str)
- parser.add_argument("--seg_file", type=str)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.seg_file)):
- os.makedirs(os.path.dirname(args.seg_file))
-
- task_list = load_scp_as_list(args.rttm_scp)
- return task_list, None, args
-
- def post(self, results_list, args):
- with open(args.seg_file, "wt", encoding="utf-8") as fd:
- for results in results_list:
- fd.writelines(results)
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- outputs = []
- for mid, rttm_path in task_list:
- spk_turns = []
- length = 0
- for one_line in open(rttm_path, 'rt', encoding="utf-8"):
- parts = one_line.strip().split(" ")
- _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- st, ed = int(st*100), int((st + dur)*100)
- length = ed if ed > length else length
- spk_turns.append([mid, st, ed, spk_name])
- is_sph = np.zeros((length+1, ), dtype=bool)
- for _, st, ed, _ in spk_turns:
- is_sph[st:ed] = True
-
- st, in_speech = 0, False
- for i in range(length+1):
- if not in_speech and is_sph[i]:
- st, in_speech = i, True
- if in_speech and not is_sph[i]:
- in_speech = False
- outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
- mid, st, i, mid, float(st)/100, float(i)/100
- ))
- return outputs
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
deleted file mode 100644
index a2bcd39..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import soundfile
-import kaldiio
-from tqdm import tqdm
-import json
-import os
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import numpy as np
-import argparse
-import random
-
-short_spk_list = []
-def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
- all_utts = spk2utt[spk]
- idx_list = list(range(len(all_utts)))
- random.shuffle(idx_list)
- count = 0
- utt_list = []
- for i in idx_list:
- utt_id = all_utts[i]
- utt_list.append(utt_id)
- count += int(utt2frames[utt_id])
- if count >= total_len:
- break
- if count < 300 and spk not in short_spk_list:
-        print("Speaker {} has only {} frames, but at least {} frames are expected; using them all.".format(spk, count, 300))
- short_spk_list.append(spk)
-
- ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
- ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
- ivc = np.concatenate(ivc_list, axis=0)
- ivc = np.mean(ivc, axis=0, keepdims=False)
- return ivc
-
-
-def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
- out_prefix = args.out
-
- ivc_dim = 192
- win_len, win_shift = 400, 160
- label_weights = 2 ** np.array(list(range(args.n_spk)))
- wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
- ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
- label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
-
-
- frames_list = []
- chunk_size = int(args.chunk_size * args.sr)
- chunk_shift = int(args.chunk_shift * args.sr)
- for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
- meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
- num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
- meeting_labels = np.load(labels_scp[mid])
- for i in range(num_chunk):
- st, ed = i*chunk_shift, i*chunk_shift+chunk_size
- seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
- wav_writer(seg_id, meeting_wav[st: ed])
-
- xvec_list = []
- for spk in meeting2spk_list[mid]:
- spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
- xvec_list.append(spk_xvec)
- for _ in range(args.n_spk - len(xvec_list)):
- xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
- xvec = np.row_stack(xvec_list)
- ivc_writer(seg_id, xvec)
-
- wav_label = meeting_labels[st:ed, :]
- frame_num = (ed-st) // win_shift
- # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
- feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
- for i in range(frame_num):
- frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
- feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
- label_writer(seg_id, feat_label)
-
- frames_list.append((mid, feat_label.shape[0]))
- return frames_list
-
-
-def calc_spk_list(rttm_path):
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
-
- return spk_list
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--dir", required=True, type=str, default=None,
- help="feats.scp")
- parser.add_argument("--out", required=True, type=str, default=None,
-                        help="The prefix of dumped files.")
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--use_lfr", default=False, action="store_true")
- parser.add_argument("--no_pbar", default=False, action="store_true")
- parser.add_argument("--sr", type=int, default=16000)
- parser.add_argument("--chunk_size", type=int, default=16)
- parser.add_argument("--chunk_shift", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out)):
- os.makedirs(os.path.dirname(args.out))
-
- meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
- labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
- utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
- utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
- utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
- utt2frames = {}
- for uttid, wav_path in utt2wav.items():
- wav, sr = soundfile.read(wav_path, dtype="int16")
- utt2frames[uttid] = int(len(wav) / sr * 100)
-
- meeting2spk_list = {}
- for mid, rttm_path in rttm_scp:
- meeting2spk_list[mid] = calc_spk_list(rttm_path)
-
- spk2utt = {}
- for utt, spk in utt2spk.items():
- if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
- if spk not in spk2utt:
- spk2utt[spk] = []
- spk2utt[spk].append(utt)
-
- # random.shuffle(feat_scp)
- meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
- total_frames = sum([x[1] for x in meeting_lens])
- print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
-
-
-if __name__ == '__main__':
- main()
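The chunk count in `process` above is the standard sliding-window formula. A quick check with the script's defaults (16 s chunks, 4 s shift, 16 kHz):

```python
sr = 16000
chunk_size, chunk_shift = 16 * sr, 4 * sr
wav_len = 60 * sr                                    # a 60 s meeting, for illustration
num_chunk = (wav_len - chunk_size) // chunk_shift + 1
assert num_chunk == 12                               # windows start at 0, 4, ..., 44 s
```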
diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
deleted file mode 100644
index 1d6f53e..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import print_function
-import numpy as np
-import os
-import sys
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import librosa
-import soundfile as sf
-from copy import deepcopy
-import json
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- rttm_scp = load_scp_as_dict(args.rttm_scp)
- task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in open(rttm_path, 'rt'):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttm_path in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
deleted file mode 100644
index 8b3195f..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
- vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
- meeting2vad = {}
- for one_line in vad_file:
- uid, mid, st, ed = one_line.strip().split(" ")
- st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
- if mid not in meeting2vad:
- meeting2vad[mid] = []
- meeting2vad[mid].append((uid, st, ed))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, vad_list in task_list:
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- seg_list = []
- pos_map = []
- offset = 0
- for uid, st, ed in vad_list:
- seg_list.append(wav[st: ed])
- pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
- offset = offset + ed - st
- out = np.concatenate(seg_list, axis=0)
- save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
- sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
- with open(map_path, "wt", encoding="utf-8") as fd:
- fd.writelines(pos_map)
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
deleted file mode 100644
index f61b808..0000000
--- a/egs/mars/sd/scripts/simu_chunk_with_labels.py
+++ /dev/null
@@ -1,261 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-import random
-from typing import List, Dict
-from copy import deepcopy
-import json
-logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--label_scp", type=str, required=True)
- parser.add_argument("--wav_scp", type=str, required=True)
- parser.add_argument("--utt2spk", type=str, required=True)
- parser.add_argument("--spk2meeting", type=str, required=True)
- parser.add_argument("--utt2xvec", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--chunk_size", type=float, default=16)
- parser.add_argument("--chunk_shift", type=float, default=4)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- parser.add_argument("--embedding_dim", type=int, default=None)
- parser.add_argument("--average_emb_num", type=int, default=0)
- parser.add_argument("--subset", type=int, default=0)
- parser.add_argument("--data_json", type=str, default=None)
- parser.add_argument("--seed", type=int, default=1234)
- parser.add_argument("--log_interval", type=int, default=100)
- args = parser.parse_args()
- random.seed(args.seed)
- np.random.seed(args.seed)
-
- logging.info("Loading data...")
- if not os.path.exists(args.data_json):
- label_list = load_scp_as_list(args.label_scp)
- wav_scp = load_scp_as_dict(args.wav_scp)
- utt2spk = load_scp_as_dict(args.utt2spk)
- utt2xvec = load_scp_as_dict(args.utt2xvec)
- spk2meeting = load_scp_as_dict(args.spk2meeting)
-
- meeting2spks = OrderedDict()
- for spk, meeting in spk2meeting.items():
- if meeting not in meeting2spks:
- meeting2spks[meeting] = []
- meeting2spks[meeting].append(spk)
-
- spk2utts = OrderedDict()
- for utt, spk in utt2spk.items():
- if spk not in spk2utts:
- spk2utts[spk] = []
- spk2utts[spk].append(utt)
-
- os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
- logging.info("Dump data...")
- json.dump({
- "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
- "spk2utts": spk2utts, "meeting2spks": meeting2spks
- }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
- else:
- data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
- label_list = data_dict["label_list"]
- wav_scp = data_dict["wav_scp"]
- utt2xvec = data_dict["utt2xvec"]
- spk2utts = data_dict["spk2utts"]
- meeting2spks = data_dict["meeting2spks"]
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- args.chunk_size = int(args.chunk_size / args.frame_shift)
- args.chunk_shift = int(args.chunk_shift / args.frame_shift)
-
- if args.embedding_dim is None:
- args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
- logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
-
- logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
- len(wav_scp), len(spk2utts), len(meeting2spks)
- ))
- return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
-
- def post(self, results_list, args):
- logging.info("[main]: Got {} chunks.".format(sum(results_list)))
-
-
-def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
- utt_list = spk2utts[spk]
- wav_list = []
- cur_length = 0
- while cur_length < sample_length:
- uttid = random.choice(utt_list)
- wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
- wav_list.append(wav)
- cur_length += len(wav)
- concat_wav = np.concatenate(wav_list, axis=0)
- start = random.randint(0, len(concat_wav) - sample_length)
- return concat_wav[start: start+sample_length]
-
-
-def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
- # process for dummy speaker
- if spk == "None":
- return np.zeros((1, embedding_dim), dtype=np.float32)
-
- # calculate averaged speaker embeddings
- utt_list = spk2utts[spk]
- if average_emb_num == 0 or average_emb_num > len(utt_list):
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
- else:
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
- xvec = np.concatenate(xvec_list, axis=0)
- xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
- xvec = np.mean(xvec, axis=0)
-
- return xvec
-
-
-def simu_chunk(
- frame_label: np.ndarray,
- sample_label: np.ndarray,
- wav_scp: Dict[str, str],
- utt2xvec: Dict[str, str],
- spk2utts: Dict[str, List[str]],
- meeting2spks: Dict[str, List[str]],
- all_speaker_list: List[str],
- meeting_list: List[str],
- embedding_dim: int,
- average_emb_num: int,
-):
- frame_length, max_spk_num = frame_label.shape
- sample_length = sample_label.shape[0]
- positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
- pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
-
- # get positive speakers
- if len(pos_speaker_list) >= positive_speaker_num:
- pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
- else:
- while len(pos_speaker_list) < positive_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list:
- pos_speaker_list.append(_spk)
-
- # get negative speakers
- negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
- neg_speaker_list = []
- while len(neg_speaker_list) < negative_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
- neg_speaker_list.append(_spk)
- neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
-
- random.shuffle(pos_speaker_list)
- random.shuffle(neg_speaker_list)
- seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
- this_spk_list = []
- for idx, frame_num in enumerate(frame_label.sum(axis=0)):
- if frame_num > 0:
- spk = pos_speaker_list.pop(0)
- this_spk_list.append(spk)
- simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
- seperated_wav[:, idx] = simu_spk_wav
- else:
- spk = neg_speaker_list.pop(0)
- this_spk_list.append(spk)
-
- # calculate mixed wav
- mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
-
- # shuffle the order of speakers
- shuffle_idx = list(range(max_spk_num))
- random.shuffle(shuffle_idx)
- this_spk_list = [this_spk_list[x] for x in shuffle_idx]
- seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
- frame_label = frame_label.transpose()[shuffle_idx].transpose()
-
- # calculate profile
- profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
- for spk in this_spk_list]
- profile = np.vstack(profile)
- # pse_weights = 2 ** np.arange(max_spk_num)
- # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
- # pse_label = pse_label.astype(str).tolist()
-
- return mixed_wav, seperated_wav, profile, frame_label
-
-
-def process(task_args):
- task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
- logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
-
- out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
- wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
- # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
- profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
-
- labels_list = []
- total_chunks = 0
- for org_mid, label_path in task_list:
- whole_label = kaldiio.load_mat(label_path)
- # random offset to keep diversity
- rand_shift = random.randint(0, args.chunk_shift)
- num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
- labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
- total_chunks += num_chunk
-
- idx = 0
- simu_chunk_count = 0
- for org_mid, whole_label, rand_shift, num_chunk in labels_list:
- for i in range(num_chunk):
- idx = idx + 1
- st = i * args.chunk_shift + rand_shift
- ed = i * args.chunk_shift + args.chunk_size + rand_shift
- utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
- args.subset + 1, task_idx + 1, org_mid, st, ed
- )
- frame_label = whole_label[st: ed, :]
- sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
- mix_wav, seg_wav, profile, frame_label = simu_chunk(
- frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
- speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
- )
- wav_mix_writer(utt_id, mix_wav)
- # wav_sep_writer(utt_id, seg_wav)
- profile_writer(utt_id, profile)
- label_writer(utt_id, frame_label)
-
- simu_chunk_count += 1
- if simu_chunk_count % args.log_interval == 0:
- logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
- task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
- wav_mix_writer.close()
- # wav_sep_writer.close()
- profile_writer.close()
- label_writer.close()
- logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
- return simu_chunk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
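In `process` above, frame-level labels are expanded to sample resolution with numpy's `repeat` before mixing. A minimal check of that step (16 kHz and 10 ms frames match the script's defaults):

```python
import numpy as np

sr, frame_shift = 16000, 0.01                 # 10 ms frames at 16 kHz
frame_label = np.array([[1, 0], [0, 1]])      # 2 frames, 2 speakers
sample_label = frame_label.repeat(int(sr * frame_shift), axis=0)
assert sample_label.shape == (2 * 160, 2)     # each frame becomes 160 samples
```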
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/finetune.py
deleted file mode 100644
index 3fa3f9d..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
deleted file mode 100644
index 862f881..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cantonese-CHS.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/README.md
deleted file mode 100644
index c68a8cd..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/README.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained UniASR Model
-
-### Finetune
-
-- Modify finetune training related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
-  - <strong>max_epoch:</strong> # number of training epochs
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```shell
- python finetune.py
-```
-
-### Inference
-
-You can use the finetuned model (or the pretrained one) for inference directly.
-
-- Setting parameters in `infer.py`
-  - <strong>audio_in:</strong> # supports wav, url, bytes, and parsed audio formats.
- - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
-
-- Then you can run the pipeline to infer with:
-```shell
- python infer.py
-```
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/finetune.py
deleted file mode 100644
index f15e3b9..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params.output_dir):
- os.makedirs(params.output_dir, exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params.data_path)
- kwargs = dict(
- model=params.model,
- model_revision=params.model_revision,
- data_dir=ds_dict,
- dataset_type=params.dataset_type,
- work_dir=params.output_dir,
- batch_bins=params.batch_bins,
- max_epoch=params.max_epoch,
- lr=params.lr)
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- from funasr.utils.modelscope_param import modelscope_args
- params = modelscope_args(model="speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline", data_path="./data")
-    params.output_dir = "./checkpoint"       # model save path
-    params.data_path = "./example_data/"     # data path
-    params.dataset_type = "small"            # use "small" for small datasets; if you have more than 1000 hours of data, use "large"
-    params.batch_bins = 2000                 # batch size; with dataset_type="small" the unit is fbank feature frames, with dataset_type="large" it is milliseconds
-    params.max_epoch = 20                    # maximum number of training epochs
-    params.lr = 0.00005                      # learning rate
-
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py
deleted file mode 100644
index 347d316..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == '__main__':
- audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
- output_dir = None
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in)
- print(rec_result)
-
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/finetune.py
deleted file mode 100644
index 68d7ba8..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
deleted file mode 100644
index f82c1f4..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_de.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/finetune.py
deleted file mode 100644
index 397b7ff..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
deleted file mode 100644
index 98f31b6..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/finetune.py
deleted file mode 100644
index 3846ff6..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
deleted file mode 100644
index 75e22a0..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_es.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md
deleted file mode 100644
index b68f1e9..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/README.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained UniASR Model
-
-### Finetune
-
-- Modify the finetuning-related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
- - <strong>dataset_type:</strong> # for datasets larger than 1000 hours, set to `large`; otherwise set to `small`
- - <strong>batch_bins:</strong> # batch size. If dataset_type is `small`, `batch_bins` counts feature frames; if dataset_type is `large`, `batch_bins` is the duration in ms
- - <strong>max_epoch:</strong> # number of training epochs
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Alternatively, you can use the pretrained model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
- - <strong>output_dir:</strong> # result dir
- - <strong>ngpu:</strong> # the number of GPUs for decoding
- - <strong>njob:</strong> # the number of jobs for each GPU
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
-- Results
-
-The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
-
-### Inference using local finetuned model
-
-- Modify the inference-related parameters in `infer_after_finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
- - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
-
-- Then you can run the pipeline to infer with:
-```python
- python infer_after_finetune.py
-```
-
-- Results
-
-The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
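
For reference, the CER step this README describes amounts to a single call to funasr's `compute_wer` helper, the same function the `infer.py` scripts in this patch use. A minimal sketch, assuming decoding has already produced `./results/1best_recog` and that `./data/test` holds the reference `text`:

```python
import os

from funasr.utils.compute_wer import compute_wer

# A sketch of the CER computation described above; both paths are assumptions.
data_dir = "./data/test"                   # must contain the reference transcript `text`
best_recog_path = "./results/1best_recog"  # produced by infer.py

text_ref = os.path.join(data_dir, "text")
text_hyp = os.path.join(best_recog_path, "token")
if os.path.exists(text_ref):
    # writes per-utterance results and the overall CER into text.cer
    compute_wer(text_ref, text_hyp, os.path.join(best_recog_path, "text.cer"))
```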
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py
deleted file mode 100644
index 2ecc229..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/finetune.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-
-from funasr.datasets.ms_dataset import MsDataset
-from funasr.utils.modelscope_param import modelscope_args
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params.output_dir):
- os.makedirs(params.output_dir, exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params.data_path)
- kwargs = dict(
- model=params.model,
- data_dir=ds_dict,
- dataset_type=params.dataset_type,
- work_dir=params.output_dir,
- batch_bins=params.batch_bins,
- max_epoch=params.max_epoch,
- lr=params.lr)
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = modelscope_args(model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline", data_path="./data")
- params.output_dir = "./checkpoint"  # model save path
- params.data_path = "./example_data/"  # data path
- params.dataset_type = "small"  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
- params.batch_bins = 2000  # batch size: fbank feature frames when dataset_type="small", milliseconds when dataset_type="large"
- params.max_epoch = 20  # maximum number of training epochs
- params.lr = 0.00005  # learning rate
-
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
deleted file mode 100644
index e6c39c2..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import shutil
-from multiprocessing import Pool
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_core(output_dir, split_dir, njob, idx):
- output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
- gpu_id = (int(idx) - 1) // njob
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline",
- output_dir=output_dir_job,
- batch_size=1
- )
- audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
-
-
-def modelscope_infer(params):
- # prepare for multi-GPU decoding
- ngpu = params["ngpu"]
- njob = params["njob"]
- output_dir = params["output_dir"]
- if os.path.exists(output_dir):
- shutil.rmtree(output_dir)
- os.mkdir(output_dir)
- split_dir = os.path.join(output_dir, "split")
- os.mkdir(split_dir)
- nj = ngpu * njob
- wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
- with open(wav_scp_file) as f:
- lines = f.readlines()
- num_lines = len(lines)
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(lines[start:])
- else:
- f.writelines(lines[start:end])
- start = end
-
- p = Pool(nj)
- for i in range(nj):
- p.apply_async(modelscope_infer_core,
- args=(output_dir, split_dir, njob, str(i + 1)))
- p.close()
- p.join()
-
- # combine decoding results
- best_recog_path = os.path.join(output_dir, "1best_recog")
- os.mkdir(best_recog_path)
- files = ["text", "token", "score"]
- for file in files:
- with open(os.path.join(best_recog_path, file), "w") as f:
- for i in range(nj):
- job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
- with open(job_file) as f_job:
- lines = f_job.readlines()
- f.writelines(lines)
-
- # If text exists, compute CER
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
- compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
- os.system("tail -n 3 {}".format(os.path.join(best_recog_path, "text.cer")))
-
-
-if __name__ == "__main__":
- params = {}
- params["data_dir"] = "./data/test"
- params["output_dir"] = "./results"
- params["ngpu"] = 1
- params["njob"] = 8
- modelscope_infer(params)
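
The script above derives `nj = ngpu * njob` wav.scp splits and pins each decoding job to a GPU via `CUDA_VISIBLE_DEVICES`. A hypothetical driver for two GPUs with four jobs each might look like this (values are illustrative; `modelscope_infer` is the function defined above):

```python
import os

# Illustrative only: restrict decoding to two GPUs, eight jobs in total.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

params = {}
params["data_dir"] = "./data/test"   # must contain wav.scp (and optionally text)
params["output_dir"] = "./results"
params["ngpu"] = 2                   # number of GPUs to decode on
params["njob"] = 4                   # jobs per GPU -> wav.scp is split into 8 parts
modelscope_infer(params)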
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py
deleted file mode 100644
index 6593f4e..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline/infer_after_finetune.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import json
-import os
-import shutil
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_after_finetune(params):
- # prepare for decoding
- pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
- for file_name in params["required_files"]:
- if file_name == "configuration.json":
- with open(os.path.join(pretrained_model_path, file_name)) as f:
- config_dict = json.load(f)
- config_dict["model"]["am_model_name"] = params["decoding_model_name"]
- with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
- json.dump(config_dict, f, indent=4, separators=(',', ': '))
- else:
- shutil.copy(os.path.join(pretrained_model_path, file_name),
- os.path.join(params["output_dir"], file_name))
- decoding_path = os.path.join(params["output_dir"], "decode_results")
- if os.path.exists(decoding_path):
- shutil.rmtree(decoding_path)
- os.mkdir(decoding_path)
-
- # decoding
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=params["output_dir"],
- output_dir=decoding_path,
- batch_size=1
- )
- audio_in = os.path.join(params["data_dir"], "wav.scp")
- inference_pipeline(audio_in=audio_in)
-
- # compute CER if the ground-truth text is set
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(decoding_path, "1best_recog/token")
- compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
- os.system("tail -n 3 {}".format(os.path.join(decoding_path, "text.cer")))
-
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-offline"
- params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data/test"
- params["decoding_model_name"] = "20epoch.pb"
- modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/finetune.py
deleted file mode 100644
index 4746cc2..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
deleted file mode 100644
index 627d132..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_fr.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/finetune.py
deleted file mode 100644
index 985b838..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
deleted file mode 100644
index e53c37e..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_id.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/finetune.py
deleted file mode 100644
index 5485ff5..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
deleted file mode 100644
index 68cc41d..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ja.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/finetune.py
deleted file mode 100644
index fd9c442..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
deleted file mode 100644
index b87bcbb..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ko.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/finetune.py
deleted file mode 100644
index 512b844..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
deleted file mode 100644
index 4a43e7c..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_pt.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-pt-16k-common-vocab1617-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/finetune.py
deleted file mode 100644
index 432266d..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
deleted file mode 100644
index 3c9d364..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_ru.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/finetune.py
deleted file mode 100644
index 3a90ed2..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params["output_dir"]):
- os.makedirs(params["output_dir"], exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params["data_dir"])
- kwargs = dict(
- model=params["model"],
- model_revision=params["model_revision"],
- data_dir=ds_dict,
- dataset_type=params["dataset_type"],
- work_dir=params["output_dir"],
- batch_bins=params["batch_bins"],
- max_epoch=params["max_epoch"],
- lr=params["lr"])
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = {}
- params["output_dir"] = "./checkpoint"
- params["data_dir"] = "./data"
- params["batch_bins"] = 2000
- params["dataset_type"] = "small"
- params["max_epoch"] = 50
- params["lr"] = 0.00005
- params["model"] = "damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline"
- params["model_revision"] = None
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
deleted file mode 100644
index 4218f3d..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline/infer.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == "__main__":
- audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_vi.wav"
- output_dir = "./results"
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in, param_dict={"decoding_model":"offline"})
- print(rec_result)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/README.md
deleted file mode 100644
index c68a8cd..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/README.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained UniASR Model
-
-### Finetune
-
-- Modify the finetuning-related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
- - <strong>max_epoch:</strong> # number of training epochs
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Alternatively, you can use the pretrained model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # supports wav files, URLs, bytes, and parsed audio formats
- - <strong>output_dir:</strong> # required when the input is a wav.scp list
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
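
As a sketch of the wav.scp case mentioned above (where `output_dir` must be set), one might call the pipeline like this; the local paths are assumptions:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",
    output_dir="./results",  # required when audio_in is a wav.scp list
)
# Each line of wav.scp is "<utt-id> <wav-path-or-url>"; results land in output_dir.
rec_result = inference_pipeline(audio_in="./data/test/wav.scp")
print(rec_result)
```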
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/finetune.py
deleted file mode 100644
index 73aae7d..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params.output_dir):
- os.makedirs(params.output_dir, exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params.data_path)
- kwargs = dict(
- model=params.model,
- model_revision=params.model_revision,
- data_dir=ds_dict,
- dataset_type=params.dataset_type,
- work_dir=params.output_dir,
- batch_bins=params.batch_bins,
- max_epoch=params.max_epoch,
- lr=params.lr)
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- from funasr.utils.modelscope_param import modelscope_args
- params = modelscope_args(model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline", data_path="./data")
- params.output_dir = "./checkpoint"  # model save path
- params.data_path = "./example_data/"  # data path
- params.dataset_type = "small"  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
- params.batch_bins = 2000  # batch size: fbank feature frames when dataset_type="small", milliseconds when dataset_type="large"
- params.max_epoch = 20  # maximum number of training epochs
- params.lr = 0.00005  # learning rate
-
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py
deleted file mode 100644
index 3520989..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == '__main__':
- audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
- output_dir = None
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in)
- print(rec_result)
-
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/README.md
deleted file mode 100644
index 9a84f9b..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/README.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained UniASR Model
-
-### Finetune
-
-- Modify the finetuning-related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
- - <strong>dataset_type:</strong> # for datasets larger than 1000 hours, set to `large`; otherwise set to `small`
- - <strong>batch_bins:</strong> # batch size. If dataset_type is `small`, `batch_bins` counts feature frames; if dataset_type is `large`, `batch_bins` is the duration in ms
- - <strong>max_epoch:</strong> # number of training epochs
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Alternatively, you can use the pretrained model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
- - <strong>output_dir:</strong> # result dir
- - <strong>ngpu:</strong> # the number of GPUs for decoding
- - <strong>njob:</strong> # the number of jobs for each GPU
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
-
-- Results
-
-The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
-
-### Inference using local finetuned model
-
-- Modify the inference-related parameters in `infer_after_finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
- - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pb`
-
-- Then you can run the pipeline to infer with:
-```python
- python infer_after_finetune.py
-```
-
-- Results
-
-The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set.
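
The `data_dir` layout this README (and `MsDataset.load` in the finetune scripts) expects is the Kaldi-style pair of `wav.scp` and `text` per split. A hypothetical helper that writes such a layout, shown only to make the expected file format concrete:

```python
import os

def write_split(data_dir, split, utts):
    """Write <data_dir>/<split>/{wav.scp,text}, one "<utt-id> <value>" per line."""
    split_dir = os.path.join(data_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(split_dir, "wav.scp"), "w") as wav_f, \
            open(os.path.join(split_dir, "text"), "w") as text_f:
        for utt_id, wav_path, transcript in utts:
            wav_f.write(f"{utt_id} {wav_path}\n")
            text_f.write(f"{utt_id} {transcript}\n")

# Illustrative entries; a real recipe would list every utterance in the corpus.
write_split("./data", "train", [("utt001", "/corpus/utt001.wav", "example transcript")])
write_split("./data", "validation", [("utt002", "/corpus/utt002.wav", "example transcript")])
```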
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/finetune.py
deleted file mode 100644
index b2325b2..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/finetune.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-
-from funasr.datasets.ms_dataset import MsDataset
-from funasr.utils.modelscope_param import modelscope_args
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params.output_dir):
- os.makedirs(params.output_dir, exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params.data_path)
- kwargs = dict(
- model=params.model,
- data_dir=ds_dict,
- dataset_type=params.dataset_type,
- work_dir=params.output_dir,
- batch_bins=params.batch_bins,
- max_epoch=params.max_epoch,
- lr=params.lr)
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- params = modelscope_args(model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline", data_path="./data")
- params.output_dir = "./checkpoint"  # model save path
- params.data_path = "./example_data/"  # data path
- params.dataset_type = "small"  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
- params.batch_bins = 2000  # batch size: fbank feature frames when dataset_type="small", milliseconds when dataset_type="large"
- params.max_epoch = 20  # maximum number of training epochs
- params.lr = 0.00005  # learning rate
-
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
deleted file mode 100644
index 13d2a2e..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import os
-import shutil
-from multiprocessing import Pool
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_core(output_dir, split_dir, njob, idx):
- output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
- gpu_id = (int(idx) - 1) // njob
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline",
- output_dir=output_dir_job,
- batch_size=1
- )
- audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipeline(audio_in=audio_in)
-
-def modelscope_infer(params):
- # prepare for multi-GPU decoding
- ngpu = params["ngpu"]
- njob = params["njob"]
- output_dir = params["output_dir"]
- if os.path.exists(output_dir):
- shutil.rmtree(output_dir)
- os.mkdir(output_dir)
- split_dir = os.path.join(output_dir, "split")
- os.mkdir(split_dir)
- nj = ngpu * njob
- wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
- with open(wav_scp_file) as f:
- lines = f.readlines()
- num_lines = len(lines)
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(lines[start:])
- else:
- f.writelines(lines[start:end])
- start = end
-
- p = Pool(nj)
- for i in range(nj):
- p.apply_async(modelscope_infer_core,
- args=(output_dir, split_dir, njob, str(i + 1)))
- p.close()
- p.join()
-
- # combine decoding results
- best_recog_path = os.path.join(output_dir, "1best_recog")
- os.mkdir(best_recog_path)
- files = ["text", "token", "score"]
- for file in files:
- with open(os.path.join(best_recog_path, file), "w") as f:
- for i in range(nj):
- job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
- with open(job_file) as f_job:
- lines = f_job.readlines()
- f.writelines(lines)
-
- # If text exists, compute CER
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "text")
- compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
-
-
-if __name__ == "__main__":
- params = {}
- params["data_dir"] = "./data/test"
- params["output_dir"] = "./results"
- params["ngpu"] = 1
- params["njob"] = 1
- modelscope_infer(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
deleted file mode 100644
index 1e9c4d1..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/infer_after_finetune.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import json
-import os
-import shutil
-
-from multiprocessing import Pool
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-from funasr.utils.compute_wer import compute_wer
-
-
-def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx):
- output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
- gpu_id = (int(idx) - 1) // njob
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=model_dir,
- output_dir=output_dir_job,
- batch_size=1
- )
- audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
- inference_pipeline(audio_in=audio_in)
-
-def modelscope_infer_after_finetune(params):
- # prepare for multi-GPU decoding
- model_dir = params["model_dir"]
- pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
- for file_name in params["required_files"]:
- if file_name == "configuration.json":
- with open(os.path.join(pretrained_model_path, file_name)) as f:
- config_dict = json.load(f)
- config_dict["model"]["am_model_name"] = params["decoding_model_name"]
- with open(os.path.join(model_dir, "configuration.json"), "w") as f:
- json.dump(config_dict, f, indent=4, separators=(',', ': '))
- else:
- shutil.copy(os.path.join(pretrained_model_path, file_name),
- os.path.join(model_dir, file_name))
- ngpu = params["ngpu"]
- njob = params["njob"]
- output_dir = params["output_dir"]
- if os.path.exists(output_dir):
- shutil.rmtree(output_dir)
- os.mkdir(output_dir)
- split_dir = os.path.join(output_dir, "split")
- os.mkdir(split_dir)
- nj = ngpu * njob
- wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
- with open(wav_scp_file) as f:
- lines = f.readlines()
- num_lines = len(lines)
- num_job_lines = num_lines // nj
- start = 0
- for i in range(nj):
- end = start + num_job_lines
- file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
- with open(file, "w") as f:
- if i == nj - 1:
- f.writelines(lines[start:])
- else:
- f.writelines(lines[start:end])
- start = end
-
- p = Pool(nj)
- for i in range(nj):
- p.apply_async(modelscope_infer_after_finetune_core,
- args=(model_dir, output_dir, split_dir, njob, str(i + 1)))
- p.close()
- p.join()
-
- # combine decoding results
- best_recog_path = os.path.join(output_dir, "1best_recog")
- os.mkdir(best_recog_path)
- files = ["text", "token", "score"]
- for file in files:
- with open(os.path.join(best_recog_path, file), "w") as f:
- for i in range(nj):
- job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
- with open(job_file) as f_job:
- lines = f_job.readlines()
- f.writelines(lines)
-
- # If text exists, compute CER
- text_in = os.path.join(params["data_dir"], "text")
- if os.path.exists(text_in):
- text_proc_file = os.path.join(best_recog_path, "token")
- compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
-
-if __name__ == '__main__':
- params = {}
- params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline"
- params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
- params["model_dir"] = "./checkpoint"
- params["output_dir"] = "./results"
- params["data_dir"] = "./data/test"
- params["decoding_model_name"] = "20epoch.pb"
- params["ngpu"] = 1
- params["njob"] = 1
- modelscope_infer_after_finetune(params)
-
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/README.md b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/README.md
deleted file mode 100644
index c68a8cd..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/README.md
+++ /dev/null
@@ -1,30 +0,0 @@
-# ModelScope Model
-
-## How to finetune and infer using a pretrained UniASR Model
-
-### Finetune
-
-- Modify the finetuning-related parameters in `finetune.py`
- - <strong>output_dir:</strong> # result dir
- - <strong>data_dir:</strong> # the dataset dir needs to include files: train/wav.scp, train/text; validation/wav.scp, validation/text.
- - <strong>batch_bins:</strong> # batch size
- - <strong>max_epoch:</strong> # number of training epochs
- - <strong>lr:</strong> # learning rate
-
-- Then you can run the pipeline to finetune with:
-```python
- python finetune.py
-```
-
-### Inference
-
-Alternatively, you can use the pretrained model for inference directly.
-
-- Setting parameters in `infer.py`
- - <strong>audio_in:</strong> # supports wav files, URLs, bytes, and parsed audio formats
- - <strong>output_dir:</strong> # required when the input is a wav.scp list
-
-- Then you can run the pipeline to infer with:
-```python
- python infer.py
-```
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/finetune.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/finetune.py
deleted file mode 100644
index b18296e..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/finetune.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os
-from modelscope.metainfo import Trainers
-from modelscope.trainers import build_trainer
-from funasr.datasets.ms_dataset import MsDataset
-
-
-def modelscope_finetune(params):
- if not os.path.exists(params.output_dir):
- os.makedirs(params.output_dir, exist_ok=True)
- # dataset split ["train", "validation"]
- ds_dict = MsDataset.load(params.data_path)
- kwargs = dict(
- model=params.model,
- model_revision=params.model_revision,
- data_dir=ds_dict,
- dataset_type=params.dataset_type,
- work_dir=params.output_dir,
- batch_bins=params.batch_bins,
- max_epoch=params.max_epoch,
- lr=params.lr)
- trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
- trainer.train()
-
-
-if __name__ == '__main__':
- from funasr.utils.modelscope_param import modelscope_args
- params = modelscope_args(model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline", data_path="./data")
- params.output_dir = "./checkpoint"  # model save path
- params.data_path = "./example_data/"  # data path
- params.dataset_type = "small"  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
- params.batch_bins = 2000  # batch size: fbank feature frames when dataset_type="small", milliseconds when dataset_type="large"
- params.max_epoch = 20  # maximum number of training epochs
- params.lr = 0.00005  # learning rate
-
- modelscope_finetune(params)
diff --git a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py b/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py
deleted file mode 100644
index 8ec4288..0000000
--- a/egs_modelscope/asr/uniasr/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/infer.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == '__main__':
- audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
- output_dir = None
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline",
- output_dir=output_dir,
- )
- rec_result = inference_pipeline(audio_in=audio_in)
- print(rec_result)
-
diff --git a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/demo.py
similarity index 100%
rename from egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
rename to egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/demo.py
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
index a6629cd..c449ab2 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
@@ -1,7 +1,3 @@
-
-##################text二进制数据#####################
-inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学规划和论证兼顾上下游的利益"
-
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
@@ -13,9 +9,12 @@
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
- output_dir="./tmp/"
+ model_revision='v1.0.2'
)
+##################text binary data#####################
+inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学规划和论证兼顾上下游的利益"
+
vads = inputs.split("|")
rec_result_all="outputs:"
param_dict = {"cache": []}
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
index 20994d3..7383a58 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/demo.py
@@ -1,14 +1,4 @@
-##################text.scp文件路径###################
-inputs = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
-
-##################text二进制数据#####################
-#inputs = "我们都是木头人不会讲话不会动"
-
-##################text文件url#######################
-#inputs = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
-
-
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -19,5 +9,14 @@
output_dir="./tmp/"
)
+##################text.scp###################
+# inputs = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
+
+##################text#####################
+#inputs = "鎴戜滑閮芥槸鏈ㄥご浜轰笉浼氳璇濅笉浼氬姩"
+
+##################text file url#######################
+inputs = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
+
rec_result = inference_pipeline(text_in=inputs)
print(rec_result)
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
index 3db6f7d..9e80d2b 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/infer.py
@@ -7,8 +7,9 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-# 初始化推理 pipeline
-# 当以原始音频作为输入时，使用配置文件 sond.yaml，并设置 mode 为 sond_demo
+# initialize the inference pipeline
+# when using raw waveform files as input, please use the config file `sond.yaml`
+# and set mode to `sond_demo`
inference_diar_pipline = pipeline(
mode="sond_demo",
num_workers=0,
@@ -19,7 +20,8 @@
sv_model_revision="master",
)
-# 以 audio_list 作为输入，其中第一个音频为待检测语音，后面的音频为不同说话人的声纹注册语音
+# use audio_list as the input, where the first audio is the recording to be detected
+# and the following ones are enrollment utterances from different speakers
audio_list = [
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
index db10193..dc867b0 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
@@ -7,8 +7,9 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-# 初始化推理 pipeline
-# 当以原始音频作为输入时，使用配置文件 sond.yaml，并设置 mode 为 sond_demo
+# initialize the pipeline for inference
+# when using raw waveform files for inference, please use the config file `sond.yaml`
+# and set mode to `sond_demo`
inference_diar_pipline = pipeline(
mode="sond_demo",
num_workers=0,
@@ -19,7 +20,8 @@
sv_model_revision="master",
)
-# 以 audio_list 作为输入，其中第一个音频为待检测语音，后面的音频为不同说话人的声纹注册语音
+# use audio_list as the input, where the first one is the record to be detected
+# and the following files are enrollments for different speakers
audio_list = [
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/demo.py
similarity index 100%
rename from egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/infer.py
rename to egs_modelscope/speaker_verification/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/demo.py
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/demo.py
similarity index 100%
rename from egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer.py
rename to egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/demo.py
diff --git a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer_sv.py b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer_sv.py
index c51313d..7a53827 100644
--- a/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer_sv.py
+++ b/egs_modelscope/speaker_verification/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/infer_sv.py
@@ -7,13 +7,13 @@
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
)
- # 两个语音为相同说话人
+ # the same speaker
rec_result = inference_sv_pipline(audio_in=(
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
print("Similarity", rec_result["scores"])
- # 两个语音为不同说话人
+ # different speaker
rec_result = inference_sv_pipline(audio_in=(
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
diff --git a/egs_modelscope/tp/TEMPLATE/infer.py b/egs_modelscope/tp/TEMPLATE/infer.py
index 6a7e496..732c984 100644
--- a/egs_modelscope/tp/TEMPLATE/infer.py
+++ b/egs_modelscope/tp/TEMPLATE/infer.py
@@ -8,6 +8,7 @@
inference_pipeline = pipeline(
task=Tasks.speech_timestamp,
model=args.model,
+ model_revision='v1.1.0',
output_dir=args.output_dir,
batch_size=args.batch_size,
)
@@ -21,7 +22,7 @@
parser.add_argument('--model', type=str, default="damo/speech_timestamp_prediction-v1-16k-offline")
parser.add_argument('--audio_in', type=str, default="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav")
    parser.add_argument('--text_in', type=str, default="一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢")
- parser.add_argument('--output_dir', type=str, default="./results/")
+ parser.add_argument('--output_dir', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--gpuid', type=str, default="0")
args = parser.parse_args()
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
index bcc5128..3116f6d 100644
--- a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo.py
@@ -4,6 +4,7 @@
inference_pipeline = pipeline(
task=Tasks.speech_timestamp,
model='damo/speech_timestamp_prediction-v1-16k-offline',
+ model_revision='v1.1.0',
output_dir=None)
rec_result = inference_pipeline(
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
new file mode 100644
index 0000000..d9d413b
--- /dev/null
+++ b/funasr/bin/asr_infer.py
@@ -0,0 +1,1834 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import sys
+import time
+import copy
+import os
+import codecs
+import tempfile
+import requests
+from pathlib import Path
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+from typeguard import check_return_type
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.beam_search import BeamSearch
+# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
+from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis as HypothesisSAASR
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.bin.punc_infer import Text2Punc
+from funasr.utils.vad_utils import slice_padding_fbank
+from funasr.tasks.vad import VADTask
+from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
+from funasr.tasks.asr import frontend_choices
+
+class Speech2Text:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend == 'wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            else:
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, None, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+ from funasr.modules.beam_search.beam_search import BeamSearch
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
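+        # pre-beam pruning uses the combined "full" score unless decoding is CTC-only (ctc_weight == 1.0)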
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ) -> List[
+ Tuple[
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, _ = self.asr_model.encode(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ assert len(enc) == 1, len(enc)
+
+        # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+class Speech2TextParaformer:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextParaformer("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        if asr_model.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = asr_model.token_list
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+ from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ # 6. [Optional] Build hotword list from str, local file or url
+ self.hotword_list = None
+ self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+
+ is_use_lm = lm_weight != 0.0 and lm_file is not None
+        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
+ beam_search = None
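+        # with neither CTC nor an external LM, scoring reduces to a per-token argmax over
+        # the decoder posteriors (see __call__), so beam search is skipped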
+ self.beam_search = beam_search
+ logging.info(f"Beam_search: {self.beam_search}")
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+ self.encoder_downsampling_factor = 1
+ if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
+ self.encoder_downsampling_factor = 4
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ begin_time: int = 0, end_time: int = None,
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
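+        # lfr_factor: low-frame-rate stacking factor inferred from the feature dim,
+        # assuming an 80-dim fbank base; returned for timestamp post-processing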
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, enc_len = self.asr_model.encode(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ # assert len(enc) == 1, len(enc)
+ enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
+
+ predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+ predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
+ if torch.max(pre_token_length) < 1:
+ return []
+ if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
+ if self.hotword_list:
+ logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ else:
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+ decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+ if isinstance(self.asr_model, BiCifParaformer):
+ _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ pre_token_length) # test no bias cif2
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = enc[i, :enc_len[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+                # wrap the greedy 1-best sequence with sos/eos ids to form a standard Hypothesis
+ yseq = torch.tensor(
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank (id 0) and id 2 (presumably </s>), which carry no text
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ timestamp = []
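+                # BiCifParaformer emits upsampled alphas/peaks, apparently at 3x the
+                # encoder frame rate, hence the enc_len[i]*3 slicing below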
+ if isinstance(self.asr_model, BiCifParaformer):
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
+ us_peaks[i][:enc_len[i]*3],
+ copy.copy(token),
+ vad_offset=begin_time)
+ results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
+
+
+ # assert check_return_type(results)
+ return results
+
+ def generate_hotwords_list(self, hotword_list_or_file):
+ # for None
+ if hotword_list_or_file is None:
+ hotword_list = None
+ # for local txt inputs
+ elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords from local txt...")
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
+        elif hotword_list_or_file.startswith('http'):
+            logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.mkdtemp()
+            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            with open(text_file_path, "wb") as fout:
+                fout.write(local_file.content)
+            hotword_list_or_file = text_file_path
+ hotword_list = []
+ hotword_str_list = []
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
+ # for text str input
+ elif not hotword_list_or_file.endswith('.txt'):
+ logging.info("Attempting to parse hotwords as str...")
+ hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ else:
+ hotword_list = None
+ return hotword_list
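+    # Hotword sources handled above (sketch): a local .txt file or a URL with one
+    # hotword per line, or a whitespace-separated string such as "阿里巴巴 达摩院"
+    # (hypothetical examples). Each hotword is split into characters, mapped to
+    # token ids, and a final '<s>' (sos) entry is appended as a sentinel.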
+
+class Speech2TextParaformerOnline:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextParaformerOnline("asr_config.yml", "asr.pth")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ frontend_conf: dict = None,
+ hotword_list_or_file: str = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        if asr_model.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = asr_model.token_list
+ scorers.update(
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+ from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+        # 6. hotword list is not built here: hotword_list_or_file is accepted but unused in the online class
+
+ is_use_lm = lm_weight != 0.0 and lm_file is not None
+        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
+ beam_search = None
+ self.beam_search = beam_search
+ logging.info(f"Beam_search: {self.beam_search}")
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+ self.encoder_downsampling_factor = 1
+ if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
+ self.encoder_downsampling_factor = 4
+
+ @torch.no_grad()
+ def __call__(
+        self, cache: dict, speech: torch.Tensor, speech_lengths: torch.Tensor = None
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+ results = []
+ cache_en = cache["encoder"]
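+        # a final chunk shorter than 16 * 60 samples (60 ms at 16 kHz, presumably) is
+        # decoded from the cached tail features rather than the raw input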
+ if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+ if cache_en["start_idx"] == 0:
+ return []
+ cache_en["tail_chunk"] = True
+ feats = cache_en["feats"]
+ feats_len = torch.tensor([feats.shape[1]])
+ self.asr_model.frontend = None
+ results = self.infer(feats, feats_len, cache)
+ return results
+ else:
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+
+ if feats.shape[1] != 0:
+ results = self.infer(feats, feats_len, cache)
+
+ return results
+
+ @torch.no_grad()
+    def infer(self, feats: torch.Tensor, feats_len: torch.Tensor, cache: dict = None):
+ batch = {"speech": feats, "speech_lengths": feats_len}
+ batch = to_device(batch, device=self.device)
+ # b. Forward Encoder
+ enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ # assert len(enc) == 1, len(enc)
+ enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
+
+ predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
+        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
+ if torch.max(pre_token_length) < 1:
+ return []
+ decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
+ decoder_out = decoder_outs
+
+ results = []
+ b, n, d = decoder_out.size()
+ for i in range(b):
+ x = enc[i, :enc_len[i], :]
+ am_scores = decoder_out[i, :pre_token_length[i], :]
+ if self.beam_search is not None:
+ nbest_hyps = self.beam_search(
+ x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+ else:
+ yseq = am_scores.argmax(dim=-1)
+ score = am_scores.max(dim=-1)[0]
+ score = torch.sum(score, dim=-1)
+                # wrap the greedy 1-best sequence with sos/eos ids to form a standard Hypothesis
+ yseq = torch.tensor(
+ [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
+ )
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank (id 0) and id 2 (presumably </s>), which carry no text
+ token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+ token = " ".join(token)
+
+ results.append(token)
+
+ # assert check_return_type(results)
+ return results
+
+class Speech2TextUniASR:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextUniASR("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ token_num_relax: int = 1,
+ decoding_ind: int = 0,
+ decoding_mode: str = "model1",
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ scorers = {}
+ from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+ if decoding_mode == "model1":
+ decoder = asr_model.decoder
+ else:
+ decoder = asr_model.decoder2
+
+        if asr_model.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+ from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+ for scorer in scorers.values():
+ if isinstance(scorer, torch.nn.Module):
+ scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+ # logging.info(f"Beam_search: {beam_search}")
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.token_num_relax = token_num_relax
+ self.decoding_ind = decoding_ind
+ self.decoding_mode = decoding_mode
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ) -> List[
+ Tuple[
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ feats_raw = feats.clone().to(self.device)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+ # b. Forward Encoder
+ _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+ assert len(enc) == 1, len(enc)
+ if self.decoding_mode == "model1":
+ predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
+ else:
+ enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
+ predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
+
+ scama_mask = predictor_outs[4]
+ pre_token_length = predictor_outs[1]
+ pre_acoustic_embeds = predictor_outs[0]
+ maxlen = pre_token_length.sum().item() + self.token_num_relax
+ minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
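+        # bound the search length by the predictor's token-count estimate, relaxed by
+        # token_num_relax on either side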
+        # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
+ minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+ token = list(filter(lambda x: x != "<gbg>", token))
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+
+class Speech2TextMFCCA:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextMFCCA("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ lm.to(device)
+ scorers["lm"] = lm.lm
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+ # beam_search.__class__ = BatchBeamSearch
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ) -> List[
+ Tuple[
+ Optional[str],
+ List[str],
+ List[int],
+ Union[Hypothesis],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+        if speech.dim() == 3:
+ speech = torch.squeeze(speech, 2)
+ # speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ speech = speech.to(getattr(torch, self.dtype))
+        # lengths: (1,)
+ lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ batch = {"speech": speech, "speech_lengths": lengths}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, _ = self.asr_model.encode(**batch)
+
+ assert len(enc) == 1, len(enc)
+
+        # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
+
+
+class Speech2TextTransducer:
+ """Speech2Text class for Transducer models.
+ Args:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model config path.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ device: Device to use for inference.
+ beam_size: Size of beam during search.
+ dtype: Data type.
+ lm_weight: Language model weight.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ nbest: Number of final hypothesis.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
+ ) -> None:
+ """Construct a Speech2Text object."""
+ super().__init__()
+
+ assert check_argument_types()
+ from funasr.tasks.asr import ASRTransducerTask
+ asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+ if quantize_asr_model:
+ if quantize_modules is not None:
+ if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+ raise ValueError(
+ "Only 'Linear' and 'LSTM' modules are currently supported"
+ " by PyTorch and in --quantize_modules"
+ )
+
+ q_config = set([getattr(torch.nn, q) for q in quantize_modules])
+ else:
+ q_config = {torch.nn.Linear}
+
+            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+                raise ValueError(
+                    "float16 dtype for dynamic quantization is not supported with torch"
+                    " version < 1.5.0. Please use qint8 dtype instead."
+                )
+ q_dtype = getattr(torch, quantize_dtype)
+
+ asr_model = torch.quantization.quantize_dynamic(
+ asr_model, q_config, dtype=q_dtype
+ ).eval()
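+            # dynamic quantization replaces the selected Linear/LSTM modules with int8
+            # (or fp16) weight kernels to speed up CPU inference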
+ else:
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, device
+ )
+ lm_scorer = lm.lm
+ else:
+ lm_scorer = None
+
+ # 4. Build BeamSearch object
+ if beam_search_config is None:
+ beam_search_config = {}
+
+ beam_search = BeamSearchTransducer(
+ asr_model.decoder,
+ asr_model.joint_network,
+ beam_size,
+ lm=lm_scorer,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ **beam_search_config,
+ )
+
+ token_list = asr_model.token_list
+
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.beam_search = beam_search
+ self.streaming = streaming
+ self.simu_streaming = simu_streaming
+ self.chunk_size = max(chunk_size, 0)
+ self.left_context = left_context
+ self.right_context = max(right_context, 0)
+
+ if not streaming or chunk_size == 0:
+ self.streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ if not simu_streaming or chunk_size == 0:
+ self.simu_streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ self.frontend = frontend
+ self.window_size = self.chunk_size + self.right_context
+
+ if self.streaming:
+ self._ctx = self.asr_model.encoder.get_encoder_input_size(
+ self.window_size
+ )
+
+ self.last_chunk_length = (
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ )
+ self.reset_inference_cache()
+
+ def reset_inference_cache(self) -> None:
+ """Reset Speech2Text parameters."""
+ self.frontend_cache = None
+
+ self.asr_model.encoder.reset_streaming_cache(
+ self.left_context, device=self.device
+ )
+ self.beam_search.reset_inference_cache()
+
+ self.num_processed_frames = torch.tensor([[0]], device=self.device)
+
+ @torch.no_grad()
+ def streaming_decode(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
+ ) -> List[HypothesisTransducer]:
+ """Speech2Text streaming call.
+ Args:
+ speech: Chunk of speech data. (S)
+ is_final: Whether speech corresponds to the final chunk of data.
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if is_final:
+ if self.streaming and speech.size(0) < self.last_chunk_length:
+ pad = torch.zeros(
+ self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
+ )
+                speech = torch.cat([speech, pad], dim=0)
+        # feats, feats_length = self.apply_frontend(speech, is_final=is_final)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.chunk_forward(
+ feats,
+ feats_lengths,
+ self.num_processed_frames,
+ chunk_size=self.chunk_size,
+ left_context=self.left_context,
+ right_context=self.right_context,
+ )
+ nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+
+ self.num_processed_frames += self.chunk_size
+
+ if is_final:
+ self.reset_inference_cache()
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context,
+ self.right_context)
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+
+ enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ def hypotheses_to_results(self, nbest_hyps: List[HypothesisTransducer]) -> List[Any]:
+ """Build partial or final results from the hypotheses.
+ Args:
+ nbest_hyps: N-best hypothesis.
+ Returns:
+ results: Results containing different representation for the hypothesis.
+ """
+ results = []
+
+ for hyp in nbest_hyps:
+ token_int = list(filter(lambda x: x != 0, hyp.yseq))
+
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
+
+ assert check_return_type(results)
+
+ return results
+
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+    ) -> "Speech2TextTransducer":
+ """Build Speech2Text instance from the pretrained model.
+ Args:
+ model_tag: Model tag of the pretrained models.
+ Return:
+        : Speech2TextTransducer instance.
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2TextTransducer(**kwargs)
+
+
+class Speech2TextSAASR:
+ """Speech2Text class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2text = Speech2TextSAASR("asr_config.yml", "asr.pb")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2text(audio)
+ [(text, token, token_int, hypothesis object), ...]
+
+ """
+
+ def __init__(
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ beam_size: int = 20,
+ ctc_weight: float = 0.5,
+ lm_weight: float = 1.0,
+ ngram_weight: float = 0.9,
+ penalty: float = 0.0,
+ nbest: int = 1,
+ streaming: bool = False,
+ frontend_conf: dict = None,
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build ASR model
+ from funasr.tasks.sa_asr import ASRTask
+ scorers = {}
+ asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device
+ )
+ frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend == 'wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ else:
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
+
+ logging.info("asr_model: {}".format(asr_model))
+ logging.info("asr_train_args: {}".format(asr_train_args))
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+ decoder = asr_model.decoder
+
+ ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+ token_list = asr_model.token_list
+ scorers.update(
+ decoder=decoder,
+ ctc=ctc,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+ # 2. Build Language model
+ if lm_train_config is not None:
+ lm, lm_train_args = LMTask.build_model_from_file(
+ lm_train_config, lm_file, None, device
+ )
+ scorers["lm"] = lm.lm
+
+ # 3. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ # 4. Build BeamSearch object
+ # transducer is not supported now
+ beam_search_transducer = None
+ from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+
+ weights = dict(
+ decoder=1.0 - ctc_weight,
+ ctc=ctc_weight,
+ lm=lm_weight,
+ ngram=ngram_weight,
+ length_bonus=penalty,
+ )
+ beam_search = BeamSearch(
+ beam_size=beam_size,
+ weights=weights,
+ scorers=scorers,
+ sos=asr_model.sos,
+ eos=asr_model.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+ )
+
+ # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+ if token_type is None:
+ token_type = asr_train_args.token_type
+ if bpemodel is None:
+ bpemodel = asr_train_args.bpemodel
+
+ if token_type is None:
+ tokenizer = None
+ elif token_type == "bpe":
+ if bpemodel is not None:
+ tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+ else:
+ tokenizer = None
+ else:
+ tokenizer = build_tokenizer(token_type=token_type)
+ converter = TokenIDConverter(token_list=token_list)
+ logging.info(f"Text tokenizer: {tokenizer}")
+
+ self.asr_model = asr_model
+ self.asr_train_args = asr_train_args
+ self.converter = converter
+ self.tokenizer = tokenizer
+ self.beam_search = beam_search
+ self.beam_search_transducer = beam_search_transducer
+ self.maxlenratio = maxlenratio
+ self.minlenratio = minlenratio
+ self.device = device
+ self.dtype = dtype
+ self.nbest = nbest
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray],
+ profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
+ ) -> List[
+ Tuple[
+ Optional[str],
+ Optional[str],
+ List[str],
+ List[int],
+ Union[HypothesisSAASR],
+ ]
+ ]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ text, text_id, token, token_int, hyp
+
+ """
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if isinstance(profile, np.ndarray):
+ profile = torch.tensor(profile)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ asr_enc, _, spk_enc = self.asr_model.encode(**batch)
+ if isinstance(asr_enc, tuple):
+ asr_enc = asr_enc[0]
+ if isinstance(spk_enc, tuple):
+ spk_enc = spk_enc[0]
+ assert len(asr_enc) == 1, len(asr_enc)
+ assert len(spk_enc) == 1, len(spk_enc)
+
+ # c. Pass the encoder result to the beam search
+ nbest_hyps = self.beam_search(
+ asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ for hyp in nbest_hyps:
+ assert isinstance(hyp, HypothesisSAASR), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1: last_pos]
+ else:
+ token_int = hyp.yseq[1: last_pos].tolist()
+
+ spk_weights = torch.stack(hyp.spk_weigths, dim=0)  # the hyp attribute keeps the upstream "spk_weigths" spelling
+
+ token_ori = self.converter.ids2tokens(token_int)
+ text_ori = self.tokenizer.tokens2text(token_ori)
+
+ text_ori_spklist = text_ori.split('$')
+ cur_index = 0
+ spk_choose = []
+ for i in range(len(text_ori_spklist)):
+ text_ori_split = text_ori_spklist[i]
+ n = len(text_ori_split)
+ spk_weights_local = spk_weights[cur_index: cur_index + n]
+ cur_index = cur_index + n + 1
+ spk_weights_local = spk_weights_local.mean(dim=0)
+ spk_choose_local = spk_weights_local.argmax(-1)
+ spk_choose.append(spk_choose_local.item() + 1)
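+
+ # What the loop above computes: spk_weights holds one row of speaker
+ # weights per decoded character ('$' turn separators included), so
+ # averaging the rows of each '$'-delimited span and taking argmax picks
+ # one speaker per turn; "+ 1" makes the ids 1-based and
+ # "cur_index + n + 1" skips the separator itself.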
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != 0, token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+
+ text_spklist = text.split('$')
+ assert len(spk_choose) == len(text_spklist)
+
+ spk_list = []
+ for i in range(len(text_spklist)):
+ text_split = text_spklist[i]
+ n = len(text_split)
+ spk_list.append(str(spk_choose[i]) * n)
+
+ text_id = '$'.join(spk_list)
+
+ assert len(text) == len(text_id)
+
+ results.append((text, text_id, token, token_int, hyp))
+
+ assert check_return_type(results)
+ return results
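+
+# Minimal usage sketch for the class above (file names and shapes are
+# illustrative only, not shipped assets):
+#   speech2text = Speech2TextSAASR("sa_asr_config.yaml", "sa_asr.pb")
+#   text, text_id, token, token_int, hyp = speech2text(
+#       speech, speech_lengths, profile, profile_lengths)[0]
+# text and text_id are aligned character for character: text_id carries the
+# chosen 1-based speaker index for every character, with '$' marking turns.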
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
deleted file mode 100644
index a52e94a..0000000
--- a/funasr/bin/asr_inference.py
+++ /dev/null
@@ -1,655 +0,0 @@
-#!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
-from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
-from funasr.modules.beam_search.beam_search import BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tasks.asr import frontend_choices
-
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend=='wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- else:
- frontend_class=frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, _ = self.asr_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
- assert len(enc) == 1, len(enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
- return results
-
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
- mc=mc,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}\n".format(text))
- return asr_result_list
-
- return _forward
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 7b04a9e..dbbb3ed 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -12,13 +15,1630 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+import sys
+import time
+import copy
+import os
+import codecs
+import tempfile
+import requests
+from pathlib import Path
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+import yaml
+import numpy as np
+import torch
+import torchaudio
+from typeguard import check_argument_types
+from typeguard import check_return_type
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.beam_search import BeamSearch
+# from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
+
+from funasr.utils.vad_utils import slice_padding_fbank
+from funasr.tasks.vad import VADTask
+from funasr.utils.timestamp_tools import time_stamp_sentence
+from funasr.bin.asr_infer import Speech2Text
+from funasr.bin.asr_infer import Speech2TextParaformer, Speech2TextParaformerOnline
+from funasr.bin.asr_infer import Speech2TextUniASR
+from funasr.bin.asr_infer import Speech2TextMFCCA
+from funasr.bin.vad_infer import Speech2VadSegment
+from funasr.bin.punc_infer import Text2Punc
+from funasr.bin.tp_infer import Speech2Timestamp
+from funasr.bin.asr_infer import Speech2TextTransducer
+from funasr.bin.asr_infer import Speech2TextSAASR
+
+def inference_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2Text(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7. Start the for-loop
+ # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}\n".format(text))
+ return asr_result_list
+
+ return _forward
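+
+# Usage sketch for the factory above (paths and the data triple are
+# illustrative):
+#   forward = inference_asr(..., asr_train_config="config.yaml",
+#                           asr_model_file="model.pb", output_dir=None, ...)
+#   results = forward([("wav.scp", "speech", "sound")])
+# The model is built once inside inference_asr; the returned _forward can be
+# called repeatedly without reloading checkpoints.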
+
+
+def inference_paraformer(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ output_dir: Optional[str] = None,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ export_mode = False
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ export_mode = param_dict.get("export_mode", False)
+ else:
+ hotword_list_or_file = None
+
+ if kwargs.get("device", None) == "cpu":
+ ngpu = 0
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
+ )
+
+ speech2text = Speech2TextParaformer(**speech2text_kwargs)
+
+ if timestamp_model_file is not None:
+ speechtext2timestamp = Speech2Timestamp(
+ timestamp_cmvn_file=cmvn_file,
+ timestamp_model_file=timestamp_model_file,
+ timestamp_infer_config=timestamp_infer_config,
+ )
+ else:
+ speechtext2timestamp = None
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ if 'hotword' in kwargs and kwargs['hotword'] is not None:
+ hotword_list_or_file = kwargs['hotword']
+ if hotword_list_or_file is not None or 'hotword' in kwargs:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
+
+ forward_time_total = 0.0
+ length_total = 0.0
+ finish_count = 0
+ file_count = 1
+ # 7. Start the for-loop
+ # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+
+ logging.info("decoding, utt_id: {}".format(keys))
+ # N-best list of (text, token, token_int, hyp_object)
+
+ time_beg = time.time()
+ results = speech2text(**batch)
+ if len(results) < 1:
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
+ time_end = time.time()
+ forward_time = time_end - time_beg
+ lfr_factor = results[0][-1]
+ length = results[0][-2]
+ forward_time_total += forward_time
+ length_total += length
+ rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time,
+ 100 * forward_time / (
+ length * lfr_factor))
+ logging.info(rtf_cur)
+
+ for batch_id in range(_bs):
+ nbest_results = [results[batch_id][:-2]]
+
+ key = keys[batch_id]
+ for n, result in zip(range(1, nbest + 1), nbest_results):
+ text, token, token_int, hyp = result[0], result[1], result[2], result[3]
+ timestamp = result[4] if len(result[4]) > 0 else None
+ # conduct timestamp prediction here
+ # timestamp inference requires token length
+ # thus following inference cannot be conducted in batch
+ if timestamp is None and speechtext2timestamp:
+ ts_batch = {}
+ ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
+ ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
+ ts_batch['text_lengths'] = torch.tensor([len(token)])
+ us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
+ ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token,
+ force_time_shift=-3.0)
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["rtf"][key] = rtf_cur
+
+ if text is not None:
+ if use_timestamp and timestamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
+ timestamp_postprocessed = ""
+ if len(postprocessed_result) == 3:
+ text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+ item = {'key': key, 'value': text_postprocessed}
+ if timestamp_postprocessed != "":
+ item['timestamp'] = timestamp_postprocessed
+ asr_result_list.append(item)
+ finish_count += 1
+ # asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = " ".join(word_lists)
+
+ logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+ rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total,
+ forward_time_total,
+ 100 * forward_time_total / (
+ length_total * lfr_factor))
+ logging.info(rtf_avg)
+ if writer is not None:
+ ibest_writer["rtf"]["rtf_avf"] = rtf_avg
+ return asr_result_list
+
+ return _forward
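+
+# On the RTF figures logged above (a sketch of the arithmetic, assuming the
+# usual 10 ms fbank frame shift): `length` counts low-frame-rate features,
+# so length * lfr_factor recovers 10 ms frames, the audio lasts
+# length * lfr_factor / 100 seconds, and
+#   rtf = forward_time / (length * lfr_factor / 100)
+#       = 100 * forward_time / (length * lfr_factor)
+# which is the factor-of-100 expression used in rtf_cur and rtf_avg.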
+
+
+def inference_paraformer_vad_punc(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ vad_infer_config: Optional[str] = None,
+ vad_model_file: Optional[str] = None,
+ vad_cmvn_file: Optional[str] = None,
+ time_stamp_writer: bool = True,
+ punc_infer_config: Optional[str] = None,
+ punc_model_file: Optional[str] = None,
+ outputs_dict: Optional[bool] = True,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2vadsegment
+ speech2vadsegment_kwargs = dict(
+ vad_infer_config=vad_infer_config,
+ vad_model_file=vad_model_file,
+ vad_cmvn_file=vad_cmvn_file,
+ device=device,
+ dtype=dtype,
+ )
+ # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+ speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
+
+ # 3. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
+ )
+ speech2text = Speech2TextParaformer(**speech2text_kwargs)
+ text2punc = None
+ if punc_model_file is not None:
+ text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
+
+ if output_dir is not None:
+ writer = DatadirWriter(output_dir)
+ ibest_writer = writer[f"1best_recog"]
+ ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=1,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+ collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
+
+ finish_count = 0
+ file_count = 1
+ lfr_factor = 6
+ # 7. Start the for-loop
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ writer = None
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ ibest_writer = writer[f"1best_recog"]
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ vad_results = speech2vadsegment(**batch)
+ _, vadsegments = vad_results[0], vad_results[1][0]
+
+ speech, speech_lengths = batch["speech"], batch["speech_lengths"]
+
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+ for j, beg_idx in enumerate(range(0, n, batch_size)):
+ end_idx = min(n, beg_idx + batch_size)
+ speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+
+ batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
+ batch = to_device(batch, device=device)
+ results = speech2text(**batch)
+
+ if len(results) < 1:
+ results = [["", [], [], [], [], [], []]]
+ results_sorted.extend(results)
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ restored_data[index] = results_sorted[j]
+ result = ["", [], [], [], [], [], []]
+ for j in range(n):
+ result[0] += restored_data[j][0]
+ result[1] += restored_data[j][1]
+ result[2] += restored_data[j][2]
+ if len(restored_data[j][4]) > 0:
+ for t in restored_data[j][4]:
+ t[0] += vadsegments[j][0]
+ t[1] += vadsegments[j][0]
+ result[4] += restored_data[j][4]
+ # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
+
+ key = keys[0]
+ # result = result_segments[0]
+ text, token, token_int = result[0], result[1], result[2]
+ time_stamp = result[4] if len(result[4]) > 0 else None
+
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed = ""
+ time_stamp_postprocessed = ""
+ if len(postprocessed_result) == 3:
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+ text_postprocessed_punc = text_postprocessed
+ punc_id_list = []
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+
+ item = {'key': key, 'value': text_postprocessed_punc}
+ if text_postprocessed != "":
+ item['text_postprocessed'] = text_postprocessed
+ if time_stamp_postprocessed != "":
+ item['time_stamp'] = time_stamp_postprocessed
+
+ item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
+ asr_result_list.append(item)
+ finish_count += 1
+ # asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["vad"][key] = "{}".format(vadsegments)
+ ibest_writer["text"][key] = " ".join(word_lists)
+ ibest_writer["text_with_punc"][key] = text_postprocessed_punc
+ if time_stamp_postprocessed != "":
+ ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
+
+ logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
+ return asr_result_list
+
+ return _forward
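+
+# Shape of the cascade implemented above:
+#   speech2vadsegment   -> vadsegments (speech regions of the utterance)
+#   slice_padding_fbank -> per-segment features, sorted by length and batched
+#   speech2text         -> text/token/timestamps, restored to original order
+#   text2punc           -> punctuated text (optional, fed the word list)
+# Segment-local timestamps are shifted by each segment's start offset
+# (vadsegments[j][0]) so they refer to the full recording.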
+
+def inference_paraformer_online(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ export_mode = False
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ )
+
+ speech2text = Speech2TextParaformerOnline(**speech2text_kwargs)
+
+ def _load_bytes(input):
+ middle_data = np.frombuffer(input, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in 'iu':
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype('float32')
+ if dtype.kind != 'f':
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2 ** (i.bits - 1)
+ offset = i.min + abs_max
+ array = (middle_data.astype(dtype) - offset) / abs_max
+ return array
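+
+ # Worked example of the normalization in _load_bytes (a sketch): for int16
+ # input, abs_max = 2 ** 15 and offset = -32768 + 32768 = 0, so the mapping
+ # is -32768 -> -1.0, 0 -> 0.0, 32767 -> ~0.99997; 16-bit little-endian PCM
+ # bytes become float32 samples in [-1.0, 1.0).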
+
+ def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
+ if not Path(yaml_path).exists():
+ raise FileNotFoundError(f'The {yaml_path} does not exist.')
+
+ with open(str(yaml_path), 'rb') as f:
+ data = yaml.load(f, Loader=yaml.Loader)
+ return data
+
+ def _prepare_cache(cache: Optional[dict] = None, chunk_size=[5, 10, 5], batch_size=1):
+ # use None instead of a shared mutable {} default
+ if cache is None:
+ cache = {}
+ if len(cache) > 0:
+ return cache
+ config = _read_yaml(asr_train_config)
+ enc_output_size = config["encoder_conf"]["output_size"]
+ feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+ cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+ "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+ cache["encoder"] = cache_en
+
+ cache_de = {"decode_fsmn": None}
+ cache["decoder"] = cache_de
+
+ return cache
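+
+ # Cache layout produced above, per stream: cache["encoder"] keeps the CIF
+ # state ("cif_hidden", "cif_alphas"), a rolling feature window of
+ # chunk_size[0] + chunk_size[2] context frames around each decoded chunk,
+ # and the chunk-boundary flags; cache["decoder"] keeps the FSMN memory
+ # ("decode_fsmn"). _cache_reset below rebuilds the same structure in place.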
+
+ def _cache_reset(cache: Optional[dict] = None, chunk_size=[5, 10, 5], batch_size=1):
+ if cache is None:
+ cache = {}
+ if len(cache) > 0:
+ config = _read_yaml(asr_train_config)
+ enc_output_size = config["encoder_conf"]["output_size"]
+ feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+ cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+ "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
+ cache["encoder"] = cache_en
+
+ cache_de = {"decode_fsmn": None}
+ cache["decoder"] = cache_de
+
+ return cache
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
+ raw_inputs = _load_bytes(data_path_and_name_and_type[0])
+ raw_inputs = torch.tensor(raw_inputs)
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, np.ndarray):
+ raw_inputs = torch.tensor(raw_inputs)
+ is_final = False
+ cache = {}
+ chunk_size = [5, 10, 5]
+ if param_dict is not None and "cache" in param_dict:
+ cache = param_dict["cache"]
+ if param_dict is not None and "is_final" in param_dict:
+ is_final = param_dict["is_final"]
+ if param_dict is not None and "chunk_size" in param_dict:
+ chunk_size = param_dict["chunk_size"]
+
+ # 7. Start the for-loop
+ # FIXME(kamo): The output format should be discussed
+ raw_inputs = torch.unsqueeze(raw_inputs, dim=0)
+ asr_result_list = []
+ cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+ item = {}
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ sample_offset = 0
+ speech_length = raw_inputs.shape[1]
+ stride_size = chunk_size[1] * 960
+ final_result = ""
+ # note: range() evaluates its step once, while sample_offset is still 0,
+ # so the step is effectively min(stride_size, speech_length)
+ for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+ if sample_offset + stride_size >= speech_length - 1:
+ stride_size = speech_length - sample_offset
+ cache["encoder"]["is_final"] = True
+ else:
+ cache["encoder"]["is_final"] = False
+ input_lens = torch.tensor([stride_size])
+ asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+ if len(asr_result) != 0:
+ final_result += " ".join(asr_result) + " "
+ item = {'key': "utt", 'value': final_result.strip()}
+ else:
+ input_lens = torch.tensor([raw_inputs.shape[1]])
+ cache["encoder"]["is_final"] = is_final
+ asr_result = speech2text(cache, raw_inputs, input_lens)
+ item = {'key': "utt", 'value': " ".join(asr_result)}
+
+ asr_result_list.append(item)
+ if is_final:
+ cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
+ return asr_result_list
+
+ return _forward
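+
+# Streaming usage sketch for the factory above (chunking and variable names
+# are illustrative):
+#   forward = inference_paraformer_online(...)
+#   param_dict = {"cache": {}, "is_final": False, "chunk_size": [5, 10, 5]}
+#   for i, chunk in enumerate(chunks):  # e.g. 600 ms waveform slices
+#       param_dict["is_final"] = i == len(chunks) - 1
+#       partial = forward(None, raw_inputs=chunk, param_dict=param_dict)
+# The encoder/decoder caches live in param_dict["cache"] across calls and are
+# re-initialized once is_final is True.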
+
+
+def inference_uniasr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ ngram_file: Optional[str] = None,
+ cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ token_num_relax: int = 1,
+ decoding_ind: int = 0,
+ decoding_mode: str = "model1",
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ if param_dict is not None and "decoding_model" in param_dict:
+ if param_dict["decoding_model"] == "fast":
+ decoding_ind = 0
+ decoding_mode = "model1"
+ elif param_dict["decoding_model"] == "normal":
+ decoding_ind = 0
+ decoding_mode = "model2"
+ elif param_dict["decoding_model"] == "offline":
+ decoding_ind = 1
+ decoding_mode = "model2"
+ else:
+ raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ ngram_file=ngram_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ token_num_relax=token_num_relax,
+ decoding_ind=decoding_ind,
+ decoding_mode=decoding_mode,
+ )
+ speech2text = Speech2TextUniASR(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+ # 7. Start the for-loop
+ # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ logging.info(f"Utterance: {key}")
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = " ".join(word_lists)
+ return asr_result_list
+
+ return _forward
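+
+# param_dict["decoding_model"] shorthand handled above:
+#   "fast"    -> decoding_ind=0, decoding_mode="model1"
+#   "normal"  -> decoding_ind=0, decoding_mode="model2"
+#   "offline" -> decoding_ind=1, decoding_mode="model2"
+# Any other value raises NotImplementedError before the model is built.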
+
+
+def inference_mfcca(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextMFCCA(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ fs=fs,
+ mc=True,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+        # 4. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+ return asr_result_list
+
+ return _forward
+
+
+def inference_transducer(
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
+):
+ """Transducer model inference.
+ Args:
+ output_dir: Output directory path.
+ batch_size: Batch decoding size.
+ dtype: Data type.
+ beam_size: Beam size.
+ ngpu: Number of GPUs.
+ seed: Random number generator seed.
+ lm_weight: Weight of language model.
+        nbest: Number of final hypotheses.
+ num_workers: Number of workers.
+        log_level: Verbosity level for logging.
+        data_path_and_name_and_type: Data path, name, and type triplets.
+ asr_train_config: ASR model training config path.
+        asr_model_file: ASR model path.
+        cmvn_file: Global CMVN file path.
+        beam_search_config: Keyword arguments for transducer beam search.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model path.
+ model_tag: Model tag.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+        key_file: Key file path.
+ allow_variable_data_keys: Whether to allow variable data keys.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+        streaming: Whether to perform chunk-by-chunk inference.
+        simu_streaming: Whether to simulate chunk-by-chunk streaming inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
+ assert check_argument_types()
+
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ beam_search_config=beam_search_config,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ dtype=dtype,
+ beam_size=beam_size,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ quantize_asr_model=quantize_asr_model,
+ quantize_modules=quantize_modules,
+ quantize_dtype=quantize_dtype,
+ streaming=streaming,
+ simu_streaming=simu_streaming,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+ speech2text = Speech2TextTransducer.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(
+ speech2text.asr_train_args, False
+ ),
+ collate_fn=ASRTask.build_collate_fn(
+ speech2text.asr_train_args, False
+ ),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+        # 4. Start for-loop
+ with DatadirWriter(output_dir) as writer:
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
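+                # Drop the batch dimension and the *_lengths keys; after this
+                # only the raw "speech" tensor remains (asserted just below).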
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
+
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
+
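+                        # Feed fixed-size chunks of speech2text._ctx samples;
+                        # the leftover tail is sent with is_final=True to flush
+                        # the streaming decoder state.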
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
+
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx : _end], is_final=False
+ )
+
+ final_hyps = speech2text.streaming_decode(
+ speech[_end : len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
+
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ ibest_writer = writer[f"{n}best_recog"]
+
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+
+ if text is not None:
+ ibest_writer["text"][key] = text
+
+ return _forward
+
+
+def inference_sa_asr(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ streaming: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ mc: bool = False,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ streaming=streaming,
+ )
+ logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+ speech2text = Speech2TextSAASR(**speech2text_kwargs)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ mc=mc,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+ collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+        # 4. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ # N-best list of (text, token, token_int, hyp_object)
+ try:
+ results = speech2text(**batch)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+ results = [[" ", ["sil"], [2], hyp]] * nbest
+
+ # Only supporting batch_size==1
+ key = keys[0]
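+            # SA-ASR hypotheses carry an extra speaker-id sequence (text_id)
+            # alongside the usual (text, token, token_int, hyp) fields.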
+ for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ # Create a directory: outdir/{n}best_recog
+ if writer is not None:
+ ibest_writer = writer[f"{n}best_recog"]
+
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
+ ibest_writer["text_id"][key] = text_id
+
+ if text is not None:
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ item = {'key': key, 'value': text_postprocessed}
+ asr_result_list.append(item)
+ finish_count += 1
+ asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ ibest_writer["text"][key] = text
+
+ logging.info("uttid: {}".format(key))
+ logging.info("text predictions: {}".format(text))
+ logging.info("text_id predictions: {}\n".format(text_id))
+ return asr_result_list
+
+ return _forward
+
+
+def inference_launch(**kwargs):
+ if 'mode' in kwargs:
+ mode = kwargs['mode']
+ else:
+ logging.info("Unknown decoding mode.")
+ return None
+ if mode == "asr":
+ return inference_asr(**kwargs)
+ elif mode == "uniasr":
+ return inference_uniasr(**kwargs)
+ elif mode == "paraformer":
+ return inference_paraformer(**kwargs)
+ elif mode == "paraformer_streaming":
+ return inference_paraformer_online(**kwargs)
+ elif mode.startswith("paraformer_vad"):
+ return inference_paraformer_vad_punc(**kwargs)
+ elif mode == "mfcca":
+ return inference_mfcca(**kwargs)
+ elif mode == "rnnt":
+ return inference_transducer(**kwargs)
+ elif mode == "sa_asr":
+ return inference_sa_asr(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
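+# Usage sketch (paths are placeholders): inference_launch builds the pipeline
+# for the requested mode and returns its _forward callable, e.g.
+#   pipeline = inference_launch(mode="paraformer", output_dir="exp/decode", ...)
+#   results = pipeline([("data/test/wav.scp", "speech", "sound")])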
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
+
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
@@ -28,7 +1648,7 @@
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
-
+
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
@@ -61,7 +1681,7 @@
default=1,
help="The number of workers used for DataLoader",
)
-
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
@@ -70,14 +1690,20 @@
action="append",
)
group.add_argument("--key_file", type=str_or_none)
+    group.add_argument(
+        "--hotword",
+        type=str_or_none,
+        default=None,
+        help="Hotword file path or hotwords separated by spaces",
+    )
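+    # The hotword value may be a local .txt file (one hotword per line), an
+    # http(s) URL to such a file, or a plain space-separated string; see
+    # generate_hotwords_list for the parsing rules.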
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
-
+ "--mc",
+ type=bool,
+ default=False,
+ help="MultiChannel input",
+ )
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
@@ -140,7 +1766,7 @@
default={},
help="The keyword arguments for transducer beam search.",
)
-
+
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
@@ -186,8 +1812,8 @@
type=bool,
default=False,
help="Whether to display partial hypotheses during chunk-by-chunk inference.",
- )
-
+ )
+
group = parser.add_argument_group("Dynamic quantization related")
group.add_argument(
"--quantize_asr_model",
@@ -211,8 +1837,8 @@
default="qint8",
choices=["float16", "qint8"],
help="Dtype for dynamic quantization.",
- )
-
+ )
+
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
@@ -239,72 +1865,6 @@
help="CTC weight in joint decoding",
)
return parser
-
-
-
-def inference_launch(**kwargs):
- if 'mode' in kwargs:
- mode = kwargs['mode']
- else:
- logging.info("Unknown decoding mode.")
- return None
- if mode == "asr":
- from funasr.bin.asr_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "uniasr":
- from funasr.bin.asr_inference_uniasr import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "paraformer":
- from funasr.bin.asr_inference_paraformer import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "paraformer_streaming":
- from funasr.bin.asr_inference_paraformer_streaming import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode.startswith("paraformer_vad"):
- from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc
- return inference_modelscope_vad_punc(**kwargs)
- elif mode == "mfcca":
- from funasr.bin.asr_inference_mfcca import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "rnnt":
- from funasr.bin.asr_inference_rnnt import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-def inference_launch_funasr(**kwargs):
- if 'mode' in kwargs:
- mode = kwargs['mode']
- else:
- logging.info("Unknown decoding mode.")
- return None
- if mode == "asr":
- from funasr.bin.asr_inference import inference
- return inference(**kwargs)
- elif mode == "sa_asr":
- from funasr.bin.sa_asr_inference import inference
- return inference(**kwargs)
- elif mode == "uniasr":
- from funasr.bin.asr_inference_uniasr import inference
- return inference(**kwargs)
- elif mode == "paraformer":
- from funasr.bin.asr_inference_paraformer import inference_modelscope
- inference_pipeline = inference_modelscope(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
- elif mode.startswith("paraformer_vad"):
- from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc
- inference_pipeline = inference_modelscope_vad_punc(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
- elif mode == "mfcca":
- from funasr.bin.asr_inference_mfcca import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "rnnt":
- from funasr.bin.asr_inference_rnnt import inference
- return inference(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
@@ -334,7 +1894,9 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch_funasr(**kwargs)
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
+
if __name__ == "__main__":
diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py
deleted file mode 100644
index e832869..0000000
--- a/funasr/bin/asr_inference_mfcca.py
+++ /dev/null
@@ -1,767 +0,0 @@
-#!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
-from funasr.modules.beam_search.beam_search import BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-import pdb
-
-
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
- )
- lm.to(device)
- scorers["lm"] = lm.lm
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
- #beam_search.__class__ = BatchBeamSearch
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if(speech.dim()==3):
- speech = torch.squeeze(speech, 2)
- #speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- speech = speech.to(getattr(torch, self.dtype))
- # lenghts: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, _ = self.asr_model.encode(**batch)
-
- assert len(enc) == 1, len(enc)
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
- return results
-
-
-# def inference(
-# maxlenratio: float,
-# minlenratio: float,
-# batch_size: int,
-# beam_size: int,
-# ngpu: int,
-# ctc_weight: float,
-# lm_weight: float,
-# penalty: float,
-# log_level: Union[int, str],
-# data_path_and_name_and_type,
-# asr_train_config: Optional[str],
-# asr_model_file: Optional[str],
-# cmvn_file: Optional[str] = None,
-# lm_train_config: Optional[str] = None,
-# lm_file: Optional[str] = None,
-# token_type: Optional[str] = None,
-# key_file: Optional[str] = None,
-# word_lm_train_config: Optional[str] = None,
-# bpemodel: Optional[str] = None,
-# allow_variable_data_keys: bool = False,
-# streaming: bool = False,
-# output_dir: Optional[str] = None,
-# dtype: str = "float32",
-# seed: int = 0,
-# ngram_weight: float = 0.9,
-# nbest: int = 1,
-# num_workers: int = 1,
-# **kwargs,
-# ):
-# assert check_argument_types()
-# if batch_size > 1:
-# raise NotImplementedError("batch decoding is not implemented")
-# if word_lm_train_config is not None:
-# raise NotImplementedError("Word LM is not implemented")
-# if ngpu > 1:
-# raise NotImplementedError("only single GPU decoding is supported")
-#
-# logging.basicConfig(
-# level=log_level,
-# format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-# )
-#
-# if ngpu >= 1 and torch.cuda.is_available():
-# device = "cuda"
-# else:
-# device = "cpu"
-#
-# # 1. Set random-seed
-# set_all_random_seed(seed)
-#
-# # 2. Build speech2text
-# speech2text_kwargs = dict(
-# asr_train_config=asr_train_config,
-# asr_model_file=asr_model_file,
-# cmvn_file=cmvn_file,
-# lm_train_config=lm_train_config,
-# lm_file=lm_file,
-# token_type=token_type,
-# bpemodel=bpemodel,
-# device=device,
-# maxlenratio=maxlenratio,
-# minlenratio=minlenratio,
-# dtype=dtype,
-# beam_size=beam_size,
-# ctc_weight=ctc_weight,
-# lm_weight=lm_weight,
-# ngram_weight=ngram_weight,
-# penalty=penalty,
-# nbest=nbest,
-# streaming=streaming,
-# )
-# logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
-# speech2text = Speech2Text(**speech2text_kwargs)
-#
-# # 3. Build data-iterator
-# loader = ASRTask.build_streaming_iterator(
-# data_path_and_name_and_type,
-# dtype=dtype,
-# batch_size=batch_size,
-# key_file=key_file,
-# num_workers=num_workers,
-# preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-# collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-# allow_variable_data_keys=allow_variable_data_keys,
-# inference=True,
-# )
-#
-# finish_count = 0
-# file_count = 1
-# # 7 .Start for-loop
-# # FIXME(kamo): The output format should be discussed about
-# asr_result_list = []
-# if output_dir is not None:
-# writer = DatadirWriter(output_dir)
-# else:
-# writer = None
-#
-# for keys, batch in loader:
-# assert isinstance(batch, dict), type(batch)
-# assert all(isinstance(s, str) for s in keys), keys
-# _bs = len(next(iter(batch.values())))
-# assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-# #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-#
-# # N-best list of (text, token, token_int, hyp_object)
-# try:
-# results = speech2text(**batch)
-# except TooShortUttError as e:
-# logging.warning(f"Utterance {keys} {e}")
-# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-# results = [[" ", ["<space>"], [2], hyp]] * nbest
-#
-# # Only supporting batch_size==1
-# key = keys[0]
-# for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-# # Create a directory: outdir/{n}best_recog
-# if writer is not None:
-# ibest_writer = writer[f"{n}best_recog"]
-#
-# # Write the result to each file
-# ibest_writer["token"][key] = " ".join(token)
-# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-# ibest_writer["score"][key] = str(hyp.score)
-#
-# if text is not None:
-# text_postprocessed = postprocess_utils.sentence_postprocess(token)
-# item = {'key': key, 'value': text_postprocessed}
-# asr_result_list.append(item)
-# finish_count += 1
-# asr_utils.print_progress(finish_count / file_count)
-# if writer is not None:
-# ibest_writer["text"][key] = text
-# return asr_result_list
-
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- fs=fs,
- mc=True,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
- return asr_result_list
-
- return _forward
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
deleted file mode 100644
index ecdb62a..0000000
--- a/funasr/bin/asr_inference_paraformer.py
+++ /dev/null
@@ -1,1027 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
-import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
-from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.bin.tp_inference import SpeechText2Timestamp
-from funasr.bin.vad_inference import Speech2VadSegment
-from funasr.bin.punctuation_infer import Text2Punc
-from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
- # 6. [Optional] Build hotword list from str, local file or url
- self.hotword_list = None
- self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- begin_time: int = 0, end_time: int = None,
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, enc_len = self.asr_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, NeatContextualParaformer):
- if self.hotword_list:
- logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
- pre_token_length) # test no bias cif2
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- timestamp = []
- if isinstance(self.asr_model, BiCifParaformer):
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
- us_peaks[i][:enc_len[i]*3],
- copy.copy(token),
- vad_offset=begin_time)
- results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
-
-
- # assert check_return_type(results)
- return results
-
- def generate_hotwords_list(self, hotword_list_or_file):
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
-
-
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- export_mode = False
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- export_mode = param_dict.get("export_mode", False)
- else:
- hotword_list_or_file = None
-
- if kwargs.get("device", None) == "cpu":
- ngpu = 0
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
- )
-
- speech2text = Speech2Text(**speech2text_kwargs)
-
- if timestamp_model_file is not None:
- speechtext2timestamp = SpeechText2Timestamp(
- timestamp_cmvn_file=cmvn_file,
- timestamp_model_file=timestamp_model_file,
- timestamp_infer_config=timestamp_infer_config,
- )
- else:
- speechtext2timestamp = None
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- if 'hotword' in kwargs and kwargs['hotword'] is not None:
- hotword_list_or_file = kwargs['hotword']
- if hotword_list_or_file is not None or 'hotword' in kwargs:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- if param_dict is not None:
- use_timestamp = param_dict.get('use_timestamp', True)
- else:
- use_timestamp = True
-
- forward_time_total = 0.0
- length_total = 0.0
- finish_count = 0
- file_count = 1
-        # 7. Start for-loop
-        # FIXME(kamo): The output format should be discussed
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
- logging.info("decoding, utt_id: {}".format(keys))
- # N-best list of (text, token, token_int, hyp_object)
-
- time_beg = time.time()
- results = speech2text(**batch)
- if len(results) < 1:
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
- time_end = time.time()
- forward_time = time_end - time_beg
- lfr_factor = results[0][-1]
- length = results[0][-2]
- forward_time_total += forward_time
- length_total += length
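-            # rtf = forward_time / audio_seconds; `length` counts LFR frames of
-            # lfr_factor * 10 ms each, hence the factor of 100 below (assuming a
-            # 10 ms frame shift before low-frame-rate stacking)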
- rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
- logging.info(rtf_cur)
-
- for batch_id in range(_bs):
-                result_cur = [results[batch_id][:-2]]
-
-                key = keys[batch_id]
-                for n, result in zip(range(1, nbest + 1), result_cur):
- text, token, token_int, hyp = result[0], result[1], result[2], result[3]
- timestamp = result[4] if len(result[4]) > 0 else None
- # conduct timestamp prediction here
- # timestamp inference requires token length
- # thus following inference cannot be conducted in batch
- if timestamp is None and speechtext2timestamp:
- ts_batch = {}
- ts_batch['speech'] = batch['speech'][batch_id].unsqueeze(0)
- ts_batch['speech_lengths'] = torch.tensor([batch['speech_lengths'][batch_id]])
- ts_batch['text_lengths'] = torch.tensor([len(token)])
- us_alphas, us_peaks = speechtext2timestamp(**ts_batch)
- ts_str, timestamp = ts_prediction_lfr6_standard(us_alphas[0], us_peaks[0], token, force_time_shift=-3.0)
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["rtf"][key] = rtf_cur
-
- if text is not None:
- if use_timestamp and timestamp is not None:
- postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
- else:
- postprocessed_result = postprocess_utils.sentence_postprocess(token)
- timestamp_postprocessed = ""
- if len(postprocessed_result) == 3:
-                        text_postprocessed, timestamp_postprocessed, word_lists = postprocessed_result
- else:
- text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
- item = {'key': key, 'value': text_postprocessed}
- if timestamp_postprocessed != "":
- item['timestamp'] = timestamp_postprocessed
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
- logging.info(rtf_avg)
- if writer is not None:
- ibest_writer["rtf"]["rtf_avf"] = rtf_avg
- return asr_result_list
-
- return _forward
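-
-# Minimal usage sketch for the factory above (values and paths are illustrative):
-#   pipeline = inference_modelscope(maxlenratio=0.0, minlenratio=0.0, batch_size=1,
-#                                   beam_size=10, ngpu=0, ctc_weight=0.0, lm_weight=0.0,
-#                                   penalty=0.0, log_level="INFO",
-#                                   asr_train_config="exp/config.yaml",
-#                                   asr_model_file="exp/model.pb")
-#   results = pipeline([("data/wav.scp", "speech", "sound")])
-#   # -> [{'key': ..., 'value': ...}, ...]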
-
-
-def inference_modelscope_vad_punc(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- vad_infer_config: Optional[str] = None,
- vad_model_file: Optional[str] = None,
- vad_cmvn_file: Optional[str] = None,
- time_stamp_writer: bool = True,
- punc_infer_config: Optional[str] = None,
- punc_model_file: Optional[str] = None,
- outputs_dict: Optional[bool] = True,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- else:
- hotword_list_or_file = None
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
- # 3. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
- )
- speech2text = Speech2Text(**speech2text_kwargs)
- text2punc = None
- if punc_model_file is not None:
- text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
-        ibest_writer = writer["1best_recog"]
- ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
-
- if 'hotword' in kwargs:
- hotword_list_or_file = kwargs['hotword']
-
- if speech2text.hotword_list is None:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=1,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- if param_dict is not None:
- use_timestamp = param_dict.get('use_timestamp', True)
- else:
- use_timestamp = True
-
- finish_count = 0
- file_count = 1
- lfr_factor = 6
-        # 7. Start for-loop
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- writer = None
- if output_path is not None:
- writer = DatadirWriter(output_path)
-            ibest_writer = writer["1best_recog"]
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- vad_results = speech2vadsegment(**batch)
- _, vadsegments = vad_results[0], vad_results[1][0]
-
- speech, speech_lengths = batch["speech"], batch["speech_lengths"]
-
- n = len(vadsegments)
- data_with_index = [(vadsegments[i], i) for i in range(n)]
- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
- results_sorted = []
- for j, beg_idx in enumerate(range(0, n, batch_size)):
- end_idx = min(n, beg_idx + batch_size)
- speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
-
- batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
- batch = to_device(batch, device=device)
- results = speech2text(**batch)
-
- if len(results) < 1:
- results = [["", [], [], [], [], [], []]]
- results_sorted.extend(results)
- restored_data = [0] * n
- for j in range(n):
- index = sorted_data[j][1]
- restored_data[index] = results_sorted[j]
- result = ["", [], [], [], [], [], []]
- for j in range(n):
- result[0] += restored_data[j][0]
- result[1] += restored_data[j][1]
- result[2] += restored_data[j][2]
- if len(restored_data[j][4]) > 0:
- for t in restored_data[j][4]:
- t[0] += vadsegments[j][0]
- t[1] += vadsegments[j][0]
- result[4] += restored_data[j][4]
- # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
-
- key = keys[0]
- # result = result_segments[0]
- text, token, token_int = result[0], result[1], result[2]
- time_stamp = result[4] if len(result[4]) > 0 else None
-
- if use_timestamp and time_stamp is not None:
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
- else:
- postprocessed_result = postprocess_utils.sentence_postprocess(token)
- text_postprocessed = ""
- time_stamp_postprocessed = ""
- if len(postprocessed_result) == 3:
-                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result
- else:
- text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
-
- text_postprocessed_punc = text_postprocessed
- punc_id_list = []
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
-
- item = {'key': key, 'value': text_postprocessed_punc}
- if text_postprocessed != "":
- item['text_postprocessed'] = text_postprocessed
- if time_stamp_postprocessed != "":
- item['time_stamp'] = time_stamp_postprocessed
-
- item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
-
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["vad"][key] = "{}".format(vadsegments)
- ibest_writer["text"][key] = " ".join(word_lists)
- ibest_writer["text_with_punc"][key] = text_postprocessed_punc
-                if time_stamp_postprocessed != "":
- ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
-
- logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
- return asr_result_list
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group.add_argument(
- "--frontend_conf",
- default=None,
- help="",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- param_dict = {'hotword': args.hotword}
- kwargs = vars(args)
- kwargs.pop("config", None)
- kwargs['param_dict'] = param_dict
- inference_pipeline = inference_modelscope(**kwargs)
- return inference_pipeline(kwargs["data_path_and_name_and_type"], param_dict=param_dict)
-
-
-if __name__ == "__main__":
- main()
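-
-# Example invocation (script path and file names are illustrative):
-#   python -m funasr.bin.asr_infer --output_dir exp/decode \
-#       --asr_train_config exp/config.yaml --asr_model_file exp/model.pb \
-#       --data_path_and_name_and_type data/wav.scp,speech,sound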
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
deleted file mode 100644
index 4f04d02..0000000
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ /dev/null
@@ -1,749 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
-import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
-import yaml
-from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-import torchaudio
-from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-
-np.set_printoptions(threshold=np.inf)
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        if asr_model.ctc is not None:
-            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-            scorers.update(ctc=ctc)
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
-        # 6. Disable beam search when neither CTC nor an LM contributes scores
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if (ctc_weight == 0.0 or asr_model.ctc is None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
-        self, cache: dict, speech: torch.Tensor, speech_lengths: torch.Tensor = None
- ):
- """Inference
-
-        Args:
-            cache: streaming cache dict holding encoder/decoder states
-            speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
- results = []
- cache_en = cache["encoder"]
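-        # 16 * 60 = 960 samples, i.e. one 60 ms LFR frame at a 16 kHz sample rate
-        # (an assumption of this setup); shorter final chunks reuse the cached feats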
- if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
- if cache_en["start_idx"] == 0:
- return []
- cache_en["tail_chunk"] = True
- feats = cache_en["feats"]
- feats_len = torch.tensor([feats.shape[1]])
- self.asr_model.frontend = None
- results = self.infer(feats, feats_len, cache)
- return results
- else:
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- if feats.shape[1] != 0:
- if cache_en["is_final"]:
- if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
- cache_en["last_chunk"] = True
- else:
- # first chunk
- feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
- feats_len = torch.tensor([feats_chunk1.shape[1]])
- results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
-
- # last chunk
- cache_en["last_chunk"] = True
- feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
- feats_len = torch.tensor([feats_chunk2.shape[1]])
- results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-
- return [" ".join(results_chunk1 + results_chunk2)]
-
- results = self.infer(feats, feats_len, cache)
-
- return results
-
- @torch.no_grad()
-    def infer(self, feats: torch.Tensor, feats_len: torch.Tensor, cache: dict = None):
- batch = {"speech": feats, "speech_lengths": feats_len}
- batch = to_device(batch, device=self.device)
- # b. Forward Encoder
- enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
-        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
- if torch.max(pre_token_length) < 1:
- return []
- decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
- decoder_out = decoder_outs
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
-                # remove blank (assumed id 0) and eos (assumed id 2) symbols
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- token = " ".join(token)
-
- results.append(token)
-
- # assert check_return_type(results)
- return results
-
-
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
-
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- export_mode = False
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- )
-
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _load_bytes(input):
-        middle_data = np.frombuffer(input, dtype=np.int16)
-        if middle_data.dtype.kind not in 'iu':
-            raise TypeError("'middle_data' must be an array of integers")
-        dtype = np.dtype('float32')
-        if dtype.kind != 'f':
-            raise TypeError("'dtype' must be a floating point type")
-
-        i = np.iinfo(middle_data.dtype)
-        abs_max = 2 ** (i.bits - 1)
-        offset = i.min + abs_max
-        # rescale the integer PCM samples directly; no buffer reinterpretation needed
-        array = (middle_data.astype(dtype) - offset) / abs_max
-        return array
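-
-    # e.g. _load_bytes(b'\x00\x08') -> array([0.0625], dtype=float32): little-endian
-    # int16 PCM bytes rescaled to float32 in [-1.0, 1.0)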
-
- def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
- if not Path(yaml_path).exists():
-            raise FileNotFoundError(f'The {yaml_path} does not exist.')
-
- with open(str(yaml_path), 'rb') as f:
- data = yaml.load(f, Loader=yaml.Loader)
- return data
-
-    def _prepare_cache(cache: Optional[dict] = None, chunk_size=[5, 10, 5], batch_size=1):
-        cache = cache if cache is not None else {}
-        if len(cache) > 0:
- return cache
- config = _read_yaml(asr_train_config)
- enc_output_size = config["encoder_conf"]["output_size"]
- feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
- cache["encoder"] = cache_en
-
- cache_de = {"decode_fsmn": None}
- cache["decoder"] = cache_de
-
- return cache
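-
-    # chunk_size = [left, current, right] context sizes in LFR frames (the default
-    # [5, 10, 5] is roughly 300/600/300 ms at 60 ms per frame, an interpretation
-    # rather than a spec); "feats" carries the left+right context between calls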
-
-    def _cache_reset(cache: Optional[dict] = None, chunk_size=[5, 10, 5], batch_size=1):
-        cache = cache if cache is not None else {}
-        if len(cache) > 0:
- config = _read_yaml(asr_train_config)
- enc_output_size = config["encoder_conf"]["output_size"]
- feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
- "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
- cache["encoder"] = cache_en
-
- cache_de = {"decode_fsmn": None}
- cache["decoder"] = cache_de
-
- return cache
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
-
- # 3. Build data-iterator
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
- raw_inputs = _load_bytes(data_path_and_name_and_type[0])
- raw_inputs = torch.tensor(raw_inputs)
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, np.ndarray):
- raw_inputs = torch.tensor(raw_inputs)
- is_final = False
- cache = {}
- chunk_size = [5, 10, 5]
- if param_dict is not None and "cache" in param_dict:
- cache = param_dict["cache"]
- if param_dict is not None and "is_final" in param_dict:
- is_final = param_dict["is_final"]
- if param_dict is not None and "chunk_size" in param_dict:
- chunk_size = param_dict["chunk_size"]
-
-        # 7. Start for-loop
-        # FIXME(kamo): The output format should be discussed
-        raw_inputs = torch.unsqueeze(raw_inputs, dim=0)
- asr_result_list = []
- cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
- item = {}
- if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            speech_length = raw_inputs.shape[1]
-            stride_size = chunk_size[1] * 960
-            final_result = ""
-            for sample_offset in range(0, speech_length, stride_size):
- if sample_offset + stride_size >= speech_length - 1:
- stride_size = speech_length - sample_offset
- cache["encoder"]["is_final"] = True
- else:
- cache["encoder"]["is_final"] = False
- input_lens = torch.tensor([stride_size])
- asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
- if len(asr_result) != 0:
- final_result += " ".join(asr_result) + " "
- item = {'key': "utt", 'value': final_result.strip()}
- else:
- input_lens = torch.tensor([raw_inputs.shape[1]])
- cache["encoder"]["is_final"] = is_final
- asr_result = speech2text(cache, raw_inputs, input_lens)
- item = {'key': "utt", 'value': " ".join(asr_result)}
-
- asr_result_list.append(item)
- if is_final:
- cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
- return asr_result_list
-
- return _forward
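-
-# Streaming usage sketch (illustrative): share a cache dict across chunks and mark
-# the final chunk through param_dict:
-#   pipeline = inference_modelscope(..., asr_train_config="exp/config.yaml",
-#                                   asr_model_file="exp/model.pb")
-#   param_dict = {"cache": {}, "is_final": False, "chunk_size": [5, 10, 5]}
-#   for chunk in audio_chunks:   # each a 1-D np.ndarray / torch.Tensor
-#       results = pipeline(None, raw_inputs=chunk, param_dict=param_dict)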
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group.add_argument(
- "--frontend_conf",
- default=None,
- help="",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- param_dict = {'hotword': args.hotword}
- kwargs = vars(args)
- kwargs.pop("config", None)
- kwargs['param_dict'] = param_dict
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
deleted file mode 100644
index bd36907..0000000
--- a/funasr/bin/asr_inference_rnnt.py
+++ /dev/null
@@ -1,734 +0,0 @@
-#!/usr/bin/env python3
-
-""" Inference class definition for Transducer models."""
-
-from __future__ import annotations
-
-import argparse
-import logging
-import math
-import sys
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-from packaging.version import parse as V
-from typeguard import check_argument_types, check_return_type
-
-from funasr.modules.beam_search.beam_search_transducer import (
- BeamSearchTransducer,
- Hypothesis,
-)
-from funasr.modules.nets_utils import TooShortUttError
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.asr import ASRTransducerTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool, str2triple_str, str_or_none
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.models.frontend.wav_frontend import WavFrontend
-
-class Speech2Text:
- """Speech2Text class for Transducer models.
- Args:
- asr_train_config: ASR model training config path.
- asr_model_file: ASR model path.
- beam_search_config: Beam search config path.
- lm_train_config: Language Model training config path.
-        lm_file: Language Model path.
- token_type: Type of token units.
- bpemodel: BPE model path.
- device: Device to use for inference.
- beam_size: Size of beam during search.
- dtype: Data type.
- lm_weight: Language model weight.
- quantize_asr_model: Whether to apply dynamic quantization to ASR model.
- quantize_modules: List of module names to apply dynamic quantization on.
- quantize_dtype: Dynamic quantization data type.
-        nbest: Number of final hypotheses.
- streaming: Whether to perform chunk-by-chunk inference.
- chunk_size: Number of frames in chunk AFTER subsampling.
- left_context: Number of frames in left context AFTER subsampling.
- right_context: Number of frames in right context AFTER subsampling.
- display_partial_hypotheses: Whether to display partial hypotheses.
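-
-    Examples:
-        >>> import soundfile
-        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")  # illustrative paths
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2text(audio)
-        [nbest_hypothesis, ...]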
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- beam_search_config: Dict[str, Any] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- beam_size: int = 5,
- dtype: str = "float32",
- lm_weight: float = 1.0,
- quantize_asr_model: bool = False,
- quantize_modules: List[str] = None,
- quantize_dtype: str = "qint8",
- nbest: int = 1,
- streaming: bool = False,
- simu_streaming: bool = False,
- chunk_size: int = 16,
- left_context: int = 32,
- right_context: int = 0,
- display_partial_hypotheses: bool = False,
- ) -> None:
- """Construct a Speech2Text object."""
- super().__init__()
-
- assert check_argument_types()
- asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
-
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- if quantize_asr_model:
- if quantize_modules is not None:
- if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
- raise ValueError(
- "Only 'Linear' and 'LSTM' modules are currently supported"
- " by PyTorch and in --quantize_modules"
- )
-
- q_config = set([getattr(torch.nn, q) for q in quantize_modules])
- else:
- q_config = {torch.nn.Linear}
-
- if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
- raise ValueError(
- "float16 dtype for dynamic quantization is not supported with torch"
- " version < 1.5.0. Switching to qint8 dtype instead."
- )
- q_dtype = getattr(torch, quantize_dtype)
-
- asr_model = torch.quantization.quantize_dynamic(
- asr_model, q_config, dtype=q_dtype
- ).eval()
- else:
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
- )
- lm_scorer = lm.lm
- else:
- lm_scorer = None
-
- # 4. Build BeamSearch object
- if beam_search_config is None:
- beam_search_config = {}
-
- beam_search = BeamSearchTransducer(
- asr_model.decoder,
- asr_model.joint_network,
- beam_size,
- lm=lm_scorer,
- lm_weight=lm_weight,
- nbest=nbest,
- **beam_search_config,
- )
-
- token_list = asr_model.token_list
-
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
-
- self.converter = converter
- self.tokenizer = tokenizer
-
- self.beam_search = beam_search
- self.streaming = streaming
- self.simu_streaming = simu_streaming
- self.chunk_size = max(chunk_size, 0)
- self.left_context = left_context
- self.right_context = max(right_context, 0)
-
- if not streaming or chunk_size == 0:
- self.streaming = False
- self.asr_model.encoder.dynamic_chunk_training = False
-
- if not simu_streaming or chunk_size == 0:
- self.simu_streaming = False
- self.asr_model.encoder.dynamic_chunk_training = False
-
- self.frontend = frontend
- self.window_size = self.chunk_size + self.right_context
-
- if self.streaming:
- self._ctx = self.asr_model.encoder.get_encoder_input_size(
- self.window_size
- )
-
- self.last_chunk_length = (
- self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
- )
- self.reset_inference_cache()
-
- def reset_inference_cache(self) -> None:
- """Reset Speech2Text parameters."""
- self.frontend_cache = None
-
- self.asr_model.encoder.reset_streaming_cache(
- self.left_context, device=self.device
- )
- self.beam_search.reset_inference_cache()
-
- self.num_processed_frames = torch.tensor([[0]], device=self.device)
-
- @torch.no_grad()
- def streaming_decode(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- is_final: bool = True,
- ) -> List[Hypothesis]:
- """Speech2Text streaming call.
- Args:
- speech: Chunk of speech data. (S)
- is_final: Whether speech corresponds to the final chunk of data.
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if is_final:
- if self.streaming and speech.size(0) < self.last_chunk_length:
- pad = torch.zeros(
- self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
- )
-                speech = torch.cat([speech, pad], dim=0)
-
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out = self.asr_model.encoder.chunk_forward(
- feats,
- feats_lengths,
- self.num_processed_frames,
- chunk_size=self.chunk_size,
- left_context=self.left_context,
- right_context=self.right_context,
- )
- nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
-
- self.num_processed_frames += self.chunk_size
-
- if is_final:
- self.reset_inference_cache()
-
- return nbest_hyps
-
- @torch.no_grad()
- def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
- """Speech2Text call.
- Args:
- speech: Speech data. (S)
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
- assert check_argument_types()
-
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- if self.asr_model.normalize is not None:
- feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context)
- nbest_hyps = self.beam_search(enc_out[0])
-
- return nbest_hyps
-
- @torch.no_grad()
- def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
- """Speech2Text call.
- Args:
- speech: Speech data. (S)
- Returns:
- nbest_hypothesis: N-best hypothesis.
- """
- assert check_argument_types()
-
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
-
- feats = to_device(feats, device=self.device)
- feats_lengths = to_device(feats_lengths, device=self.device)
-
- enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
-
- nbest_hyps = self.beam_search(enc_out[0])
-
- return nbest_hyps
-
- def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]:
- """Build partial or final results from the hypotheses.
- Args:
- nbest_hyps: N-best hypothesis.
- Returns:
- results: Results containing different representation for the hypothesis.
- """
- results = []
-
- for hyp in nbest_hyps:
- token_int = list(filter(lambda x: x != 0, hyp.yseq))
-
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
-
- return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ) -> Speech2Text:
- """Build Speech2Text instance from the pretrained model.
- Args:
- model_tag: Model tag of the pretrained models.
- Return:
- : Speech2Text instance.
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2Text(**kwargs)
-
-
-def inference(
- output_dir: str,
- batch_size: int,
- dtype: str,
- beam_size: int,
- ngpu: int,
- seed: int,
- lm_weight: float,
- nbest: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str],
- beam_search_config: Optional[dict],
- lm_train_config: Optional[str],
- lm_file: Optional[str],
- model_tag: Optional[str],
- token_type: Optional[str],
- bpemodel: Optional[str],
- key_file: Optional[str],
- allow_variable_data_keys: bool,
- quantize_asr_model: Optional[bool],
- quantize_modules: Optional[List[str]],
- quantize_dtype: Optional[str],
- streaming: Optional[bool],
- simu_streaming: Optional[bool],
- chunk_size: Optional[int],
- left_context: Optional[int],
- right_context: Optional[int],
- display_partial_hypotheses: bool,
- **kwargs,
-) -> None:
- """Transducer model inference.
- Args:
- output_dir: Output directory path.
- batch_size: Batch decoding size.
- dtype: Data type.
- beam_size: Beam size.
- ngpu: Number of GPUs.
- seed: Random number generator seed.
- lm_weight: Weight of language model.
-        nbest: Number of final hypotheses.
- num_workers: Number of workers.
- log_level: Level of verbose for logs.
-        data_path_and_name_and_type: Input (path, name, type) triplets.
- asr_train_config: ASR model training config path.
- asr_model_file: ASR model path.
- beam_search_config: Beam search config path.
- lm_train_config: Language Model training config path.
- lm_file: Language Model path.
- model_tag: Model tag.
- token_type: Type of token units.
- bpemodel: BPE model path.
- key_file: File key.
- allow_variable_data_keys: Whether to allow variable data keys.
- quantize_asr_model: Whether to apply dynamic quantization to ASR model.
- quantize_modules: List of module names to apply dynamic quantization on.
- quantize_dtype: Dynamic quantization data type.
- streaming: Whether to perform chunk-by-chunk inference.
- chunk_size: Number of frames in chunk AFTER subsampling.
- left_context: Number of frames in left context AFTER subsampling.
- right_context: Number of frames in right context AFTER subsampling.
- display_partial_hypotheses: Whether to display partial hypotheses.
- """
- assert check_argument_types()
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1:
- device = "cuda"
- else:
- device = "cpu"
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- beam_search_config=beam_search_config,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- dtype=dtype,
- beam_size=beam_size,
- lm_weight=lm_weight,
- nbest=nbest,
- quantize_asr_model=quantize_asr_model,
- quantize_modules=quantize_modules,
- quantize_dtype=quantize_dtype,
- streaming=streaming,
- simu_streaming=simu_streaming,
- chunk_size=chunk_size,
- left_context=left_context,
- right_context=right_context,
- )
- speech2text = Speech2Text.from_pretrained(
- model_tag=model_tag,
- **speech2text_kwargs,
- )
-
- # 3. Build data-iterator
- loader = ASRTransducerTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTransducerTask.build_preprocess_fn(
- speech2text.asr_train_args, False
- ),
- collate_fn=ASRTransducerTask.build_collate_fn(
- speech2text.asr_train_args, False
- ),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
-    # 4. Start for-loop
- with DatadirWriter(output_dir) as writer:
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
-
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- assert len(batch.keys()) == 1
-
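-            # streaming decode feeds fixed windows of _ctx samples and flushes the
-            # tail with is_final=True; simu-streaming/offline decode in a single call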
- try:
- if speech2text.streaming:
- speech = batch["speech"]
-
- _steps = len(speech) // speech2text._ctx
- _end = 0
- for i in range(_steps):
- _end = (i + 1) * speech2text._ctx
-
- speech2text.streaming_decode(
- speech[i * speech2text._ctx : _end], is_final=False
- )
-
- final_hyps = speech2text.streaming_decode(
- speech[_end : len(speech)], is_final=True
- )
- elif speech2text.simu_streaming:
- final_hyps = speech2text.simu_streaming_decode(**batch)
- else:
- final_hyps = speech2text(**batch)
-
- results = speech2text.hypotheses_to_results(final_hyps)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
- results = [[" ", ["<space>"], [2], hyp]] * nbest
-
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- ibest_writer = writer[f"{n}best_recog"]
-
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- ibest_writer["text"][key] = text
-
-
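-# Illustrative sketch only (not part of the original file): how the streaming
-# branch above partitions a waveform into fixed-size chunks plus a final tail.
-# `ctx` stands in for `speech2text._ctx`.
-def _chunking_sketch(speech, ctx=16):
-    steps = len(speech) // ctx
-    # full chunks are decoded with is_final=False
-    chunks = [speech[i * ctx:(i + 1) * ctx] for i in range(steps)]
-    # the remainder (possibly the whole utterance) is decoded with is_final=True
-    tail = speech[steps * ctx:]
-    return chunks, tail
-
-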
-def get_parser():
- """Get Transducer model inference parser."""
-
- parser = config_argparse.ArgumentParser(
- description="ASR Transducer Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
-        help="Pretrained model tag. If this option is specified, *_train_config "
-        "and *_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=5, help="Beam size")
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument(
- "--beam_search_config",
- default={},
- help="The keyword arguments for transducer beam search.",
- )
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- group = parser.add_argument_group("Dynamic quantization related")
-    group.add_argument(
-        "--quantize_asr_model",
-        type=str2bool,
-        default=False,
-        help="Apply dynamic quantization to ASR model.",
-    )
-    group.add_argument(
-        "--quantize_modules",
-        nargs="*",
-        default=None,
-        help="""Module names to apply dynamic quantization on.
-        The names are given as a space-separated list
-        (e.g.: --quantize_modules Linear LSTM GRU).
-        Each specified name should be an attribute of 'torch.nn', e.g.:
-        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
-    )
-    group.add_argument(
-        "--quantize_dtype",
-        type=str,
-        default="qint8",
-        choices=["float16", "qint8"],
-        help="Dtype for dynamic quantization.",
-    )
-
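-    # Note (assumed wiring, for illustration only): these options correspond to
-    # torch dynamic quantization, roughly
-    #   torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
-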
- group = parser.add_argument_group("Streaming related")
-    group.add_argument(
-        "--streaming",
-        type=str2bool,
-        default=False,
-        help="Whether to perform chunk-by-chunk inference.",
-    )
-    group.add_argument(
-        "--simu_streaming",
-        type=str2bool,
-        default=False,
-        help="Whether to simulate chunk-by-chunk inference.",
-    )
-    group.add_argument(
-        "--chunk_size",
-        type=int,
-        default=16,
-        help="Number of frames in chunk AFTER subsampling.",
-    )
-    group.add_argument(
-        "--left_context",
-        type=int,
-        default=32,
-        help="Number of frames in left context of the chunk AFTER subsampling.",
-    )
-    group.add_argument(
-        "--right_context",
-        type=int,
-        default=0,
-        help="Number of frames in right context of the chunk AFTER subsampling.",
-    )
-    group.add_argument(
-        "--display_partial_hypotheses",
-        type=str2bool,
-        default=False,
-        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
-    )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
-
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
-
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
deleted file mode 100644
index 35ecdc2..0000000
--- a/funasr/bin/asr_inference_uniasr.py
+++ /dev/null
@@ -1,694 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchScama as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- frontend_conf: dict = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
- if decoding_mode == "model1":
- decoder = asr_model.decoder
- else:
- decoder = asr_model.decoder2
-
-        if asr_model.ctc is not None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
- # logging.info(f"Beam_search: {beam_search}")
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.token_num_relax = token_num_relax
- self.decoding_ind = decoding_ind
- self.decoding_mode = decoding_mode
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ) -> List[
- Tuple[
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- feats_raw = feats.clone().to(self.device)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
- # b. Forward Encoder
- _, enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
- if isinstance(enc, tuple):
- enc = enc[0]
- assert len(enc) == 1, len(enc)
- if self.decoding_mode == "model1":
- predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
- else:
- enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
- predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
-
- scama_mask = predictor_outs[4]
- pre_token_length = predictor_outs[1]
- pre_acoustic_embeds = predictor_outs[0]
- maxlen = pre_token_length.sum().item() + self.token_num_relax
- minlen = max(0, pre_token_length.sum().item() - self.token_num_relax)
-        # c. Pass the encoder result to the beam search
- nbest_hyps = self.beam_search(
- x=enc[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=self.maxlenratio,
- minlenratio=self.minlenratio, maxlen=int(maxlen), minlen=int(minlen),
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- token = list(filter(lambda x: x != "<gbg>", token))
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
- results.append((text, token, token_int, hyp))
-
- assert check_return_type(results)
- return results
-
-
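-# Illustrative sketch only (not part of the original file): the hypothesis
-# post-processing performed in Speech2Text.__call__ above, in isolation.
-def _token_postprocess_sketch(yseq, id2token):
-    token_int = yseq[1:-1]                        # strip sos/eos
-    token_int = [t for t in token_int if t != 0]  # drop blank id, assumed to be 0
-    tokens = [id2token[t] for t in token_int]     # integer ids -> token strings
-    return [t for t in tokens if t != "<gbg>"]    # drop the garbage symbol
-
-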
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- ngram_file: Optional[str] = None,
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- ngram_file=ngram_file,
- nbest=nbest,
- num_workers=num_workers,
- token_num_relax=token_num_relax,
- decoding_ind=decoding_ind,
- decoding_mode=decoding_mode,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- ngram_file: Optional[str] = None,
- cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- token_num_relax: int = 1,
- decoding_ind: int = 0,
- decoding_mode: str = "model1",
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- if param_dict is not None and "decoding_model" in param_dict:
- if param_dict["decoding_model"] == "fast":
- decoding_ind = 0
- decoding_mode = "model1"
- elif param_dict["decoding_model"] == "normal":
- decoding_ind = 0
- decoding_mode = "model2"
- elif param_dict["decoding_model"] == "offline":
- decoding_ind = 1
- decoding_mode = "model2"
- else:
- raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- ngram_file=ngram_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- token_num_relax=token_num_relax,
- decoding_ind=decoding_ind,
- decoding_mode=decoding_mode,
- )
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
-        # 4. Start the for-loop
-        # FIXME(kamo): The output format should be discussed
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- logging.info(f"Utterance: {key}")
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
-
- if text is not None:
- text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
- return asr_result_list
-
- return _forward
-
-
-
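-# Illustrative usage only (hypothetical paths): inference_modelscope returns a
-# `_forward` closure, so a pipeline can be built once and reused; this is
-# exactly how inference() above wraps it.
-#
-#   pipeline = inference_modelscope(
-#       maxlenratio=0.0, minlenratio=0.0, batch_size=1, beam_size=20,
-#       ngpu=0, ctc_weight=0.5, lm_weight=0.0, penalty=0.0, log_level="INFO",
-#       asr_train_config="config.yaml", asr_model_file="model.pb",
-#   )
-#   results = pipeline([("wav.scp", "speech", "sound")])
-
-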
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
-        help="Pretrained model tag. If this option is specified, *_train_config "
-        "and *_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
-        "If maxlenratio=0.0 (default), it uses an end-detect function "
-        "to automatically find maximum hypothesis lengths. "
-        "If maxlenratio<0.0, its absolute value is interpreted "
-        "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-    group.add_argument("--token_num_relax", type=int, default=1,
-                       help="Relaxation of the predicted token number when "
-                            "setting max/min decoding lengths")
-    group.add_argument("--decoding_ind", type=int, default=0,
-                       help="Index passed to the encoder when decoding")
-    group.add_argument("--decoding_mode", type=str, default="model1",
-                       help="Decoder selection: model1 or model2")
- group.add_argument(
- "--ctc_weight2",
- type=float,
- default=0.0,
- help="CTC weight in joint decoding",
- )
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index a43472c..fd973a4 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
@@ -8,6 +11,12 @@
# for ASR Training
def parse_args():
parser = ASRTask.get_parser()
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="asr",
+        help="Task mode: asr, paraformer, uniasr or rnnt.",
+    )
parser.add_argument(
"--gpu_id",
type=int,
@@ -19,7 +28,17 @@
def main(args=None, cmd=None):
+
# for ASR Training
+    if args.mode == "asr":
+        from funasr.tasks.asr import ASRTask
+    elif args.mode == "paraformer":
+        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+    elif args.mode == "uniasr":
+        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+    elif args.mode == "rnnt":
+        from funasr.tasks.asr import ASRTransducerTask as ASRTask
+    else:
+        raise ValueError("Unknown mode: {}".format(args.mode))
+
ASRTask.main(args=args, cmd=cmd)
@@ -27,8 +46,7 @@
args = parse_args()
# setup local gpu_id
- if args.ngpu > 0:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
@@ -39,9 +57,10 @@
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
- if args.batch_size is not None and args.ngpu > 0:
+ if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None and args.ngpu > 0:
+ if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)
+
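+# Example invocation (hypothetical paths): the per-model entry points removed
+# below (asr_train_paraformer.py, asr_train_transducer.py, asr_train_uniasr.py)
+# are replaced by --mode on this single script, e.g.
+#   python funasr/bin/asr_train.py --mode paraformer --gpu_id 0 \
+#       --config conf/train.yaml --output_dir exp/paraformer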
diff --git a/funasr/bin/asr_train_paraformer.py b/funasr/bin/asr_train_paraformer.py
deleted file mode 100755
index 76943d5..0000000
--- a/funasr/bin/asr_train_paraformer.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
deleted file mode 100755
index fe418db..0000000
--- a/funasr/bin/asr_train_transducer.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.asr import ASRTransducerTask
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTransducerTask.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTransducerTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/asr_train_uniasr.py b/funasr/bin/asr_train_uniasr.py
deleted file mode 100755
index a40b503..0000000
--- a/funasr/bin/asr_train_uniasr.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-
-from funasr.tasks.asr import ASRTaskUniASR
-
-
-# for ASR Training
-def parse_args():
- parser = ASRTaskUniASR.get_parser()
- parser.add_argument(
- "--gpu_id",
- type=int,
- default=0,
- help="local gpu id.",
- )
- args = parser.parse_args()
- return args
-
-
-def main(args=None, cmd=None):
- # for ASR Training
- ASRTaskUniASR.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- # setup local gpu_id
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
- # DDP settings
- if args.ngpu > 1:
- args.distributed = True
- else:
- args.distributed = False
- assert args.num_worker_count == 1
-
- # re-compute batch size: when dataset type is small
- if args.dataset_type == "small":
- if args.batch_size is not None:
- args.batch_size = args.batch_size * args.ngpu
- if args.batch_bins is not None:
- args.batch_bins = args.batch_bins * args.ngpu
-
- main(args=args)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
deleted file mode 100644
index 5c30fdb..0000000
--- a/funasr/bin/build_trainer.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import os
-
-import yaml
-
-
-def update_dct(fin_configs, root):
- if root == {}:
- return {}
- for root_key, root_value in root.items():
- if not isinstance(root[root_key], dict):
- fin_configs[root_key] = root[root_key]
- else:
- if root_key in fin_configs.keys():
- result = update_dct(fin_configs[root_key], root[root_key])
- fin_configs[root_key] = result
- else:
- fin_configs[root_key] = root[root_key]
- return fin_configs
-
-
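-# Illustrative example only (not part of the original file): update_dct merges
-# nested dicts recursively, leaves in `root` overriding those in `fin_configs`.
-def _update_dct_example():
-    merged = update_dct(
-        {"optim_conf": {"lr": 1e-3, "betas": [0.9, 0.98]}},
-        {"optim_conf": {"lr": 5e-4}},
-    )
-    assert merged == {"optim_conf": {"lr": 5e-4, "betas": [0.9, 0.98]}}
-
-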
-def parse_args(mode):
- if mode == "asr":
- from funasr.tasks.asr import ASRTask as ASRTask
- elif mode == "paraformer":
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode == "paraformer_vad_punc":
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode == "uniasr":
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- elif mode == "mfcca":
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
- elif mode == "tp":
- from funasr.tasks.asr import ASRTaskAligner as ASRTask
- else:
- raise ValueError("Unknown mode: {}".format(mode))
- parser = ASRTask.get_parser()
- args = parser.parse_args()
- return args, ASRTask
-
-
-def build_trainer(modelscope_dict,
- data_dir,
- output_dir,
- train_set="train",
- dev_set="validation",
- distributed=False,
- dataset_type="small",
- batch_bins=None,
- max_epoch=None,
- optim=None,
- lr=None,
- scheduler=None,
- scheduler_conf=None,
- specaug=None,
- specaug_conf=None,
- param_dict=None,
- **kwargs):
- mode = modelscope_dict['mode']
- args, ASRTask = parse_args(mode=mode)
- # ddp related
- if args.local_rank is not None:
- distributed = True
- else:
- distributed = False
- args.local_rank = args.local_rank if args.local_rank is not None else 0
- local_rank = args.local_rank
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
- config = modelscope_dict['am_model_config']
- finetune_config = modelscope_dict['finetune_config']
- init_param = modelscope_dict['init_model']
- cmvn_file = modelscope_dict['cmvn_file']
- seg_dict_file = modelscope_dict['seg_dict']
-
- # overwrite parameters
- with open(config) as f:
- configs = yaml.safe_load(f)
- with open(finetune_config) as f:
- finetune_configs = yaml.safe_load(f)
- # set data_types
- if dataset_type == "large":
- if 'data_types' not in finetune_configs['dataset_conf']:
- finetune_configs["dataset_conf"]["data_types"] = "sound,text"
- finetune_configs = update_dct(configs, finetune_configs)
- for key, value in finetune_configs.items():
- if hasattr(args, key):
- setattr(args, key, value)
-
- # prepare data
- args.dataset_type = dataset_type
- if args.dataset_type == "small":
- args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
- ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
- args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
- ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
- elif args.dataset_type == "large":
- args.train_data_file = None
- args.valid_data_file = None
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
- args.init_param = [init_param]
- args.cmvn_file = cmvn_file
- if os.path.exists(seg_dict_file):
- args.seg_dict_file = seg_dict_file
- else:
- args.seg_dict_file = None
- args.data_dir = data_dir
- args.train_set = train_set
- args.dev_set = dev_set
- args.output_dir = output_dir
- args.gpu_id = args.local_rank
- args.config = finetune_config
- if optim is not None:
- args.optim = optim
- if lr is not None:
- args.optim_conf["lr"] = lr
- if scheduler is not None:
- args.scheduler = scheduler
- if scheduler_conf is not None:
- args.scheduler_conf = scheduler_conf
- if specaug is not None:
- args.specaug = specaug
- if specaug_conf is not None:
- args.specaug_conf = specaug_conf
- if max_epoch is not None:
- args.max_epoch = max_epoch
- if batch_bins is not None:
- if args.dataset_type == "small":
- args.batch_bins = batch_bins
- elif args.dataset_type == "large":
- args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
- if args.normalize in ["null", "none", "None"]:
- args.normalize = None
- if args.patience in ["null", "none", "None"]:
- args.patience = None
- args.local_rank = local_rank
- args.distributed = distributed
- ASRTask.finetune_args = args
-
- return ASRTask
diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
new file mode 100755
index 0000000..f2dcb1e
--- /dev/null
+++ b/funasr/bin/diar_infer.py
@@ -0,0 +1,347 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from collections import OrderedDict
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.diar import DiarTask
+from funasr.tasks.diar import EENDOLADiarTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from scipy.ndimage import median_filter
+from funasr.utils.misc import statistic_model_parameters
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+
+class Speech2DiarizationEEND:
+    """Speech2DiarizationEEND class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2diar = Speech2DiarizationEEND("diar_config.yml", "diar.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2diar(audio)
+        [per-frame speaker activity estimates, ...]
+
+ """
+
+ def __init__(
+ self,
+ diar_train_config: Union[Path, str] = None,
+ diar_model_file: Union[Path, str] = None,
+ device: str = "cpu",
+ dtype: str = "float32",
+ ):
+ assert check_argument_types()
+
+ # 1. Build Diarization model
+ diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
+ config_file=diar_train_config,
+ model_file=diar_model_file,
+ device=device
+ )
+ frontend = None
+ if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
+ frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
+
+ # set up seed for eda
+ np.random.seed(diar_train_args.seed)
+ torch.manual_seed(diar_train_args.seed)
+ torch.cuda.manual_seed(diar_train_args.seed)
+ os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
+ logging.info("diar_model: {}".format(diar_model))
+ logging.info("diar_train_args: {}".format(diar_train_args))
+ diar_model.to(dtype=getattr(torch, dtype)).eval()
+
+ self.diar_model = diar_model
+ self.diar_train_args = diar_train_args
+ self.device = device
+ self.dtype = dtype
+ self.frontend = frontend
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ speech_lengths: Union[torch.Tensor, np.ndarray] = None
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+ diarization results
+
+ """
+ assert check_argument_types()
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.diar_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ batch = {"speech": feats, "speech_lengths": feats_len}
+ batch = to_device(batch, device=self.device)
+ results = self.diar_model.estimate_sequential(**batch)
+
+ return results
+
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+ ):
+        """Build Speech2DiarizationEEND instance from the pretrained model.
+
+ Args:
+ model_tag (Optional[str]): Model tag of the pretrained models.
+ Currently, the tags of espnet_model_zoo are supported.
+
+ Returns:
+            Speech2DiarizationEEND: Speech2DiarizationEEND instance.
+
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2DiarizationEEND(**kwargs)
+
+
+class Speech2DiarizationSOND:
+    """Speech2DiarizationSOND class
+
+ Examples:
+ >>> import soundfile
+ >>> import numpy as np
+ >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
+ >>> profile = np.load("profiles.npy")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2diar(audio, profile)
+ {"spk1": [(int, int), ...], ...}
+
+ """
+
+ def __init__(
+ self,
+ diar_train_config: Union[Path, str] = None,
+ diar_model_file: Union[Path, str] = None,
+ device: Union[str, torch.device] = "cpu",
+ batch_size: int = 1,
+ dtype: str = "float32",
+ streaming: bool = False,
+ smooth_size: int = 83,
+ dur_threshold: float = 10,
+ ):
+ assert check_argument_types()
+
+ # TODO: 1. Build Diarization model
+ diar_model, diar_train_args = DiarTask.build_model_from_file(
+ config_file=diar_train_config,
+ model_file=diar_model_file,
+ device=device
+ )
+ logging.info("diar_model: {}".format(diar_model))
+ logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
+ logging.info("diar_train_args: {}".format(diar_train_args))
+ diar_model.to(dtype=getattr(torch, dtype)).eval()
+
+ self.diar_model = diar_model
+ self.diar_train_args = diar_train_args
+ self.token_list = diar_train_args.token_list
+ self.smooth_size = smooth_size
+ self.dur_threshold = dur_threshold
+ self.device = device
+ self.dtype = dtype
+
+ def smooth_multi_labels(self, multi_label):
+ multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
+ return multi_label
+
+ @staticmethod
+ def calc_spk_turns(label_arr, spk_list):
+ turn_list = []
+ length = label_arr.shape[0]
+ n_spk = label_arr.shape[1]
+ for k in range(n_spk):
+ if spk_list[k] == "None":
+ continue
+ in_utt = False
+ start = 0
+ for i in range(length):
+ if label_arr[i, k] == 1 and in_utt is False:
+ start = i
+ in_utt = True
+ if label_arr[i, k] == 0 and in_utt is True:
+ turn_list.append([spk_list[k], start, i - start])
+ in_utt = False
+ if in_utt:
+ turn_list.append([spk_list[k], start, length - start])
+ return turn_list
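+        # Illustrative example only:
+        #   calc_spk_turns(np.array([[1], [1], [0], [1]]), ["spk1"])
+        #   -> [["spk1", 0, 2], ["spk1", 3, 1]]  # [speaker, start, duration]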
+
+ @staticmethod
+ def seq2arr(seq, vec_dim=8):
+        def int2vec(x, vec_dim=8, dtype=int):
+ b = ('{:0' + str(vec_dim) + 'b}').format(x)
+ # little-endian order: lower bit first
+ return (np.array(list(b)[::-1]) == '1').astype(dtype)
+
+ # process oov
+ seq = np.array([int(x) for x in seq])
+ new_seq = []
+ for i, x in enumerate(seq):
+ if x < 2 ** vec_dim:
+ new_seq.append(x)
+ else:
+ idx_list = np.where(seq < 2 ** vec_dim)[0]
+ idx = np.abs(idx_list - i).argmin()
+ new_seq.append(seq[idx_list[idx]])
+        return np.vstack([int2vec(x, vec_dim) for x in new_seq])
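+        # Illustrative example only: PSE labels are power-set encoded speaker
+        # activity; bit k (little-endian) marks speaker k, e.g.
+        #   int2vec(5, vec_dim=4) -> array([1, 0, 1, 0])  # speakers 0 and 2 active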
+
+ def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
+ logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
+ # upsampling outputs to match inputs
+ ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
+        logits_idx = F.interpolate(
+ logits_idx.unsqueeze(1).float(),
+ size=(ut, ),
+ mode="nearest",
+ ).squeeze(1).long()
+ logits_idx = logits_idx[0].tolist()
+ pse_labels = [self.token_list[x] for x in logits_idx]
+ if output_format == "pse_labels":
+ return pse_labels, None
+
+ multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
+ multi_labels = self.smooth_multi_labels(multi_labels)
+ if output_format == "binary_labels":
+ return multi_labels, None
+
+ spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
+ spk_turns = self.calc_spk_turns(multi_labels, spk_list)
+ results = OrderedDict()
+ for spk, st, dur in spk_turns:
+ if spk not in results:
+ results[spk] = []
+ if dur > self.dur_threshold:
+ results[spk].append((st, st+dur))
+
+ # sort segments in start time ascending
+ for spk in results:
+ results[spk] = sorted(results[spk], key=lambda x: x[0])
+
+ return results, pse_labels
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ profile: Union[torch.Tensor, np.ndarray],
+ output_format: str = "speaker_turn"
+ ):
+ """Inference
+
+ Args:
+ speech: Input speech data
+ profile: Speaker profiles
+ Returns:
+ diarization results for each speaker
+
+ """
+ assert check_argument_types()
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if isinstance(profile, np.ndarray):
+ profile = torch.tensor(profile)
+
+ # data: (Nsamples,) -> (1, Nsamples)
+ speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
+ # lengths: (1,)
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
+ batch = {"speech": speech, "speech_lengths": speech_lengths,
+ "profile": profile, "profile_lengths": profile_lengths}
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ logits = self.diar_model.prediction_forward(**batch)
+ results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
+
+ return results, pse_labels
+
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+ ):
+        """Build Speech2DiarizationSOND instance from the pretrained model.
+
+ Args:
+ model_tag (Optional[str]): Model tag of the pretrained models.
+ Currently, the tags of espnet_model_zoo are supported.
+
+ Returns:
+            Speech2DiarizationSOND: Speech2DiarizationSOND instance.
+
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2DiarizationSOND(**kwargs)
+
+
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 07974c0..e0d900e 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -15,6 +16,375 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from collections import OrderedDict
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+from typeguard import check_return_type
+from scipy.signal import medfilt
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.diar import DiarTask
+from funasr.tasks.diar import EENDOLADiarTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from scipy.ndimage import median_filter
+from funasr.utils.misc import statistic_model_parameters
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
+
+def inference_sond(
+ diar_train_config: str,
+ diar_model_file: str,
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 0,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ smooth_size: int = 83,
+ dur_threshold: int = 10,
+ out_format: str = "vad",
+ param_dict: Optional[dict] = None,
+ mode: str = "sond",
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("param_dict: {}".format(param_dict))
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2a. Build speech2xvec [Optional]
+ if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
+        assert "sv_train_config" in param_dict, "sv_train_config must be provided in param_dict."
+ assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
+ sv_train_config = param_dict["sv_train_config"]
+ sv_model_file = param_dict["sv_model_file"]
+ if "model_dir" in param_dict:
+ sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
+ sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
+ from funasr.bin.sv_infer import Speech2Xvector
+ speech2xvector_kwargs = dict(
+ sv_train_config=sv_train_config,
+ sv_model_file=sv_model_file,
+ device=device,
+ dtype=dtype,
+ streaming=streaming,
+ embedding_node="resnet1_dense"
+ )
+ logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
+ speech2xvector = Speech2Xvector.from_pretrained(
+ model_tag=model_tag,
+ **speech2xvector_kwargs,
+ )
+ speech2xvector.sv_model.eval()
+
+ # 2b. Build speech2diar
+ speech2diar_kwargs = dict(
+ diar_train_config=diar_train_config,
+ diar_model_file=diar_model_file,
+ device=device,
+ dtype=dtype,
+ streaming=streaming,
+ smooth_size=smooth_size,
+ dur_threshold=dur_threshold,
+ )
+ logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
+ speech2diar = Speech2DiarizationSOND.from_pretrained(
+ model_tag=model_tag,
+ **speech2diar_kwargs,
+ )
+ speech2diar.diar_model.eval()
+
+ def output_results_str(results: dict, uttid: str):
+ rst = []
+ mid = uttid.rsplit("-", 1)[0]
+ for key in results:
+ results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
+ if out_format == "vad":
+ for spk, segs in results.items():
+ rst.append("{} {}".format(spk, segs))
+ else:
+ template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
+ for spk, segs in results.items():
+ rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
+
+ return "\n".join(rst)
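+        # Illustrative example only: frame indices (10 ms resolution) become
+        # second-level segments, e.g. with out_format="vad"
+        #   output_results_str({"spk1": [(0, 150)]}, "rec1-1") -> "spk1 [(0.0, 1.5)]"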
+
+ def _forward(
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
+ ):
+ logging.info("param_dict: {}".format(param_dict))
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, (list, tuple)):
+ if not isinstance(raw_inputs[0], List):
+ raw_inputs = [raw_inputs]
+
+            assert all([len(example) >= 2 for example in raw_inputs]), \
+                "Each test case in raw_inputs must contain at least two items (>=2): speech plus profiles."
+
+ def prepare_dataset():
+ for idx, example in enumerate(raw_inputs):
+ # read waveform file
+ example = [load_bytes(x) if isinstance(x, bytes) else x
+ for x in example]
+ example = [soundfile.read(x)[0] if isinstance(x, str) else x
+ for x in example]
+ # convert torch tensor to numpy array
+                example = [x.numpy() if isinstance(x, torch.Tensor) else x
+                           for x in example]
+ speech = example[0]
+ logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
+ profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
+ profile = torch.cat(profile, dim=0)
+ yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
+
+ loader = prepare_dataset()
+ else:
+ raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
+ else:
+ # 3. Build data-iterator
+ loader = DiarTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=None,
+ collate_fn=None,
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+        # 4. Start the for-loop
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ os.makedirs(output_path, exist_ok=True)
+ output_writer = open("{}/result.txt".format(output_path), "w")
+ pse_label_writer = open("{}/labels.txt".format(output_path), "w")
+ logging.info("Start to diarize...")
+ result_list = []
+ for idx, (keys, batch) in enumerate(loader):
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ results, pse_labels = speech2diar(**batch)
+ # Only supporting batch_size==1
+ key, value = keys[0], output_results_str(results, keys[0])
+ item = {"key": key, "value": value}
+ result_list.append(item)
+ if output_path is not None:
+ output_writer.write(value)
+ output_writer.flush()
+ pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
+ pse_label_writer.flush()
+
+ if idx % 100 == 0:
+ logging.info("Processing {:5d}: {}".format(idx, key))
+
+ if output_path is not None:
+ output_writer.close()
+ pse_label_writer.close()
+
+ return result_list
+
+ return _forward
+
+def inference_eend(
+ diar_train_config: str,
+ diar_model_file: str,
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ param_dict: Optional[dict] = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("param_dict: {}".format(param_dict))
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Build speech2diar
+ speech2diar_kwargs = dict(
+ diar_train_config=diar_train_config,
+ diar_model_file=diar_model_file,
+ device=device,
+ dtype=dtype,
+ )
+ logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
+ speech2diar = Speech2DiarizationEEND.from_pretrained(
+ model_tag=model_tag,
+ **speech2diar_kwargs,
+ )
+ speech2diar.diar_model.eval()
+
+ def output_results_str(results: dict, uttid: str):
+ rst = []
+ mid = uttid.rsplit("-", 1)[0]
+ for key in results:
+ results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
+ template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
+ for spk, segs in results.items():
+ rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
+
+ return "\n".join(rst)
+
+ def _forward(
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
+ ):
+ # 2. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
+ loader = EENDOLADiarTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
+ collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 3. Start for-loop
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ os.makedirs(output_path, exist_ok=True)
+ output_writer = open("{}/result.txt".format(output_path), "w")
+ result_list = []
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ results = speech2diar(**batch)
+
+ # post process
+ a = results[0][0].cpu().numpy()
+ a = medfilt(a, (11, 1))
+ rst = []
+ for spkid, frames in enumerate(a.T):
+ frames = np.pad(frames, (1, 1), 'constant')
+ changes, = np.where(np.diff(frames, axis=0) != 0)
+ fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+ for s, e in zip(changes[::2], changes[1::2]):
+ st = s / 10.
+ dur = (e - s) / 10.
+ rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
+
+ # Only supporting batch_size==1
+ value = "\n".join(rst)
+ item = {"key": keys[0], "value": value}
+ result_list.append(item)
+ if output_path is not None:
+ output_writer.write(value + "\n")
+ output_writer.flush()
+
+ if output_path is not None:
+ output_writer.close()
+
+ return result_list
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sond":
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "sond_demo":
+ param_dict = {
+ "extract_profile": True,
+ "sv_train_config": "sv.yaml",
+ "sv_model_file": "sv.pb",
+ }
+ if "param_dict" in kwargs and kwargs["param_dict"] is not None:
+ for key in param_dict:
+ if key not in kwargs["param_dict"]:
+ kwargs["param_dict"][key] = param_dict[key]
+ else:
+ kwargs["param_dict"] = param_dict
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "eend-ola":
+ return inference_eend(mode=mode, **kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -125,32 +495,6 @@
return parser
-def inference_launch(mode, **kwargs):
- if mode == "sond":
- from funasr.bin.sond_inference import inference_modelscope
- return inference_modelscope(mode=mode, **kwargs)
- elif mode == "sond_demo":
- from funasr.bin.sond_inference import inference_modelscope
- param_dict = {
- "extract_profile": True,
- "sv_train_config": "sv.yaml",
- "sv_model_file": "sv.pb",
- }
- if "param_dict" in kwargs and kwargs["param_dict"] is not None:
- for key in param_dict:
- if key not in kwargs["param_dict"]:
- kwargs["param_dict"][key] = param_dict[key]
- else:
- kwargs["param_dict"] = param_dict
- return inference_modelscope(mode=mode, **kwargs)
- elif mode == "eend-ola":
- from funasr.bin.eend_ola_inference import inference_modelscope
- return inference_modelscope(mode=mode, **kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
@@ -178,7 +522,8 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch(**kwargs)
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
diff --git a/funasr/bin/diar_train.py b/funasr/bin/diar_train.py
index f76d1b9..16a4bd0 100755
--- a/funasr/bin/diar_train.py
+++ b/funasr/bin/diar_train.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
deleted file mode 100755
index 87816dd..0000000
--- a/funasr/bin/eend_ola_inference.py
+++ /dev/null
@@ -1,429 +0,0 @@
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from scipy.signal import medfilt
-from typeguard import check_argument_types
-
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
-from funasr.tasks.diar import EENDOLADiarTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-class Speech2Diarization:
- """Speech2Diarlization class
-
- Examples:
- >>> import soundfile
- >>> import numpy as np
- >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
- >>> profile = np.load("profiles.npy")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2diar(audio, profile)
- {"spk1": [(int, int), ...], ...}
-
- """
-
- def __init__(
- self,
- diar_train_config: Union[Path, str] = None,
- diar_model_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- ):
- assert check_argument_types()
-
- # 1. Build Diarization model
- diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
- config_file=diar_train_config,
- model_file=diar_model_file,
- device=device
- )
- frontend = None
- if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
- frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
-
- # set up seed for eda
- np.random.seed(diar_train_args.seed)
- torch.manual_seed(diar_train_args.seed)
- torch.cuda.manual_seed(diar_train_args.seed)
- os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
- logging.info("diar_model: {}".format(diar_model))
- logging.info("diar_train_args: {}".format(diar_train_args))
- diar_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.diar_model = diar_model
- self.diar_train_args = diar_train_args
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- diarization results
-
- """
- assert check_argument_types()
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.diar_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- batch = {"speech": feats, "speech_lengths": feats_len}
- batch = to_device(batch, device=self.device)
- results = self.diar_model.estimate_sequential(**batch)
-
- return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Diarization instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Diarization: Speech2Diarization instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2Diarization(**kwargs)
-
-
-def inference_modelscope(
- diar_train_config: str,
- diar_model_file: str,
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- param_dict: Optional[dict] = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Build speech2diar
- speech2diar_kwargs = dict(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2Diarization.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
- speech2diar.diar_model.eval()
-
- def output_results_str(results: dict, uttid: str):
- rst = []
- mid = uttid.rsplit("-", 1)[0]
- for key in results:
- results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
- template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
- for spk, segs in results.items():
- rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
- return "\n".join(rst)
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- # 2. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
- loader = EENDOLADiarTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
- collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 3. Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- output_writer = open("{}/result.txt".format(output_path), "w")
- result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- results = speech2diar(**batch)
-
- # post process
- a = results[0][0].cpu().numpy()
- a = medfilt(a, (11, 1))
- rst = []
- for spkid, frames in enumerate(a.T):
- frames = np.pad(frames, (1, 1), 'constant')
- changes, = np.where(np.diff(frames, axis=0) != 0)
- fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
- for s, e in zip(changes[::2], changes[1::2]):
- st = s / 10.
- dur = (e - s) / 10.
- rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
-
- # Only supporting batch_size==1
- value = "\n".join(rst)
- item = {"key": keys[0], "value": value}
- result_list.append(item)
- if output_path is not None:
- output_writer.write(value)
- output_writer.flush()
-
- if output_path is not None:
- output_writer.close()
-
- return result_list
-
- return _forward
-
-
-def inference(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- diar_train_config: Optional[str],
- diar_model_file: Optional[str],
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 0,
- seed: int = 0,
- num_workers: int = 1,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: int = 10,
- out_format: str = "vad",
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- output_dir=output_dir,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- model_tag=model_tag,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- smooth_size=smooth_size,
- dur_threshold=dur_threshold,
- out_format=out_format,
- **kwargs,
- )
-
- return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speaker verification/x-vector extraction",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--diar_train_config",
- type=str,
- help="diarization training configuration",
- )
- group.add_argument(
- "--diar_model_file",
- type=str,
- help="diarization model parameter file",
- )
- group.add_argument(
- "--dur_threshold",
- type=int,
- default=10,
- help="The threshold for short segments in number frames"
- )
- parser.add_argument(
- "--smooth_size",
- type=int,
- default=83,
- help="The smoothing window length in number frames"
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument("--streaming", type=str2bool, default=False)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- logging.info("args: {}".format(kwargs))
- if args.output_dir is None:
- jobid, n_gpu = 1, 1
- gpuid = args.gpuid_list.split(",")[jobid - 1]
- else:
- jobid = int(args.output_dir.split(".")[-1])
- n_gpu = len(args.gpuid_list.split(","))
- gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- results_list = inference(**kwargs)
- for results in results_list:
- print("{} {}".format(results["key"], results["value"]))
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
deleted file mode 100755
index 198d578..0000000
--- a/funasr/bin/lm_calc_perplexity.py
+++ /dev/null
@@ -1,211 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.lm import LMTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-
-def calc_perplexity(
- output_dir: str,
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float],
- allow_variable_data_keys: bool,
-):
- assert check_argument_types()
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1:
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build LM
- model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
- # Wrape model to make model.nll() data-parallel
- wrapped_model = ForwardAdaptor(model, "nll")
- wrapped_model.to(dtype=getattr(torch, dtype)).eval()
- logging.info(f"Model:\n{model}")
-
- # 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=LMTask.build_preprocess_fn(train_args, False),
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 4. Start for-loop
- with DatadirWriter(output_dir) as writer:
- total_nll = 0.0
- total_ntokens = 0
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- # NOTE(kamo): data_parallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
-
- assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
- # nll: (B, L) -> (B,)
- nll = nll.detach().cpu().numpy().sum(1)
- # lengths: (B,)
- lengths = lengths.detach().cpu().numpy()
- total_nll += nll.sum()
- total_ntokens += lengths.sum()
-
- for key, _nll, ntoken in zip(keys, nll, lengths):
- if log_base is None:
- utt_ppl = np.exp(_nll / ntoken)
- else:
- utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
-
- # Write PPL of each utts for debugging or analysis
- writer["utt2nll"][key] = str(-_nll)
- writer["utt2ppl"][key] = str(utt_ppl)
- writer["utt2ntokens"][key] = str(ntoken)
-
- if log_base is None:
- ppl = np.exp(total_nll / total_ntokens)
- else:
- ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
- with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f:
- f.write(f"{ppl}\n")
- with (Path(output_dir) / "base").open("w", encoding="utf-8") as f:
- if log_base is None:
- _log_base = np.e
- else:
- _log_base = log_base
- f.write(f"{_log_base}\n")
- logging.info(f"PPL={ppl}")
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Calc perplexity",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument(
- "--log_base",
- type=float_or_none,
- default=None,
- help="The base of logarithm for Perplexity. "
- "If None, napier's constant is used.",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- calc_perplexity(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
deleted file mode 100644
index 76de6df..0000000
--- a/funasr/bin/lm_inference.py
+++ /dev/null
@@ -1,406 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from typeguard import check_argument_types
-
-from funasr.tasks.lm import LMTask
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-def inference(
- output_dir: str,
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float],
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[List[Any], bytes, str] = None,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- output_dir=output_dir,
- raw_inputs=raw_inputs,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- train_config=train_config,
- model_file=model_file,
- log_base = log_base,
- allow_variable_data_keys = allow_variable_data_keys,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- log_base: Optional[float] = 10,
- allow_variable_data_keys: bool = False,
- split_with_space: Optional[bool] = False,
- seg_dict_file: Optional[str] = None,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build Model
- model, train_args = LMTask.build_model_from_file(
- train_config, model_file, device)
- wrapped_model = ForwardAdaptor(model, "nll")
- wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- logging.info(f"Model:\n{model}")
-
- preprocessor = LMPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file
- )
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: dict = None,
- ):
- results = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "lm demo"
- if line=="":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- batch = {}
- batch['text'] = line
- if preprocessor != None:
- batch = preprocessor(key, batch)
-
- # Force data-precision
- for name in batch:
- value = batch[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f"All values must be converted to np.ndarray object "
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
- # Cast to desired type
- if value.dtype.kind == "f":
- value = value.astype("float32")
- elif value.dtype.kind == "i":
- value = value.astype("long")
- else:
- raise NotImplementedError(f"Not supported dtype: {value.dtype}")
- batch[name] = value
-
- batch["text_lengths"] = torch.from_numpy(
- np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = np.expand_dims(batch["text"], axis=0)
-
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## compute ppl
- ppl_out_batch = ""
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for sent_ids, sent_nll in zip(batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key+":\n"] = ppl_out
- results.append(item)
-
- return results
-
- # 3. Build data-iterator
- loader = LMTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=LMTask.build_collate_fn(train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 4. Start for-loop
- total_nll = 0.0
- total_ntokens = 0
- ppl_out_all = ""
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- ppl_out_batch = ""
- with torch.no_grad():
- batch = to_device(batch, device)
- if ngpu <= 1:
- # NOTE(kamo): data_parallel also should work with ngpu=1,
- # but for debuggability it's better to keep this block.
- nll, lengths = wrapped_model(**batch)
- else:
- nll, lengths = data_parallel(
- wrapped_model, (), range(ngpu), module_kwargs=batch
- )
- ## print ppl
- ids2tokens = preprocessor.token_id_converter.ids2tokens
- for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
- pre_word = "<s>"
- cur_word = None
- sent_lst = ids2tokens(sent_ids) + ['</s>']
- ppl_out = " ".join(sent_lst) + "\n"
- for word, word_nll in zip(sent_lst, sent_nll):
- cur_word = word
- word_nll = -word_nll.cpu()
- if log_base is None:
- word_prob = np.exp(word_nll)
- else:
- word_prob = log_base ** (word_nll / np.log(log_base))
- ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
- cur=cur_word,
- pre=pre_word,
- prob=round(word_prob.item(), 8),
- word_nll=round(word_nll.item(), 8)
- )
- pre_word = cur_word
-
- sent_nll_mean = sent_nll.mean().cpu().numpy()
- sent_nll_sum = sent_nll.sum().cpu().numpy()
- if log_base is None:
- sent_ppl = np.exp(sent_nll_mean)
- else:
- sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
- ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
- sent_nll=round(-sent_nll_sum.item(), 4),
- sent_ppl=round(sent_ppl.item(), 4)
- )
- ppl_out_batch += ppl_out
- utt2nll = round(-sent_nll_sum.item(), 5)
- item = {'key': key, 'value': ppl_out}
- if writer is not None:
- writer["ppl"][key+":\n"] = ppl_out
- writer["utt2nll"][key] = str(utt2nll)
- results.append(item)
-
- ppl_out_all += ppl_out_batch
-
- assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
- # nll: (B, L) -> (B,)
- nll = nll.detach().cpu().numpy().sum(1)
- # lengths: (B,)
- lengths = lengths.detach().cpu().numpy()
- total_nll += nll.sum()
- total_ntokens += lengths.sum()
-
- if log_base is None:
- ppl = np.exp(total_nll / total_ntokens)
- else:
- ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
-
- avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
- total_nll=round(-total_nll.item(), 4),
- total_ppl=round(ppl.item(), 4)
- )
- item = {'key': 'AVG PPL', 'value': avg_ppl}
- ppl_out_all += avg_ppl
- if writer is not None:
- writer["ppl"]["AVG PPL : "] = avg_ppl
- results.append(item)
-
- return results
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Calc perplexity",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument(
- "--log_base",
- type=float_or_none,
- default=10,
- help="The base of logarithm for Perplexity. "
- "If None, napier's constant is used.",
- required=False
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- required=False
- )
- group.add_argument(
- "--raw_inputs",
- type=str,
- required=False
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group.add_argument("--split_with_space", type=str2bool, default=False)
- group.add_argument("--seg_dict_file", type=str_or_none)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- inference(**kwargs)
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
index dc6414f..3d38f64 100644
--- a/funasr/bin/lm_inference_launch.py
+++ b/funasr/bin/lm_inference_launch.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
-
-
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -14,8 +15,294 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+
+
+def inference_lm(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ log_base: Optional[float] = 10,
+ allow_variable_data_keys: bool = False,
+ split_with_space: Optional[bool] = False,
+ seg_dict_file: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build Model
+ model, train_args = LMTask.build_model_from_file(
+ train_config, model_file, device)
+ wrapped_model = ForwardAdaptor(model, "nll")
+ wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ logging.info(f"Model:\n{model}")
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file
+ )
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ else:
+ writer = None
+
+ if raw_inputs is not None:
+ line = raw_inputs.strip()
+ key = "lm demo"
+ if line == "":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ batch = {}
+ batch['text'] = line
+ if preprocessor is not None:
+ batch = preprocessor(key, batch)
+
+ # Force data-precision
+ for name in batch:
+ value = batch[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype("float32")
+ elif value.dtype.kind == "i":
+ value = value.astype("long")
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ batch[name] = value
+
+ batch["text_lengths"] = torch.from_numpy(
+ np.array([len(batch["text"])], dtype='int32'))
+ batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## compute ppl
+ ppl_out_batch = ""
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for sent_ids, sent_nll in zip(batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key + ":\n"] = ppl_out
+ results.append(item)
+
+ return results
+
+ # 3. Build data-iterator
+ loader = LMTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=LMTask.build_collate_fn(train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ # 4. Start for-loop
+ total_nll = 0.0
+ total_ntokens = 0
+ ppl_out_all = ""
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ ppl_out_batch = ""
+ with torch.no_grad():
+ batch = to_device(batch, device)
+ if ngpu <= 1:
+ # NOTE(kamo): data_parallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ nll, lengths = wrapped_model(**batch)
+ else:
+ nll, lengths = data_parallel(
+ wrapped_model, (), range(ngpu), module_kwargs=batch
+ )
+ ## print ppl
+ ids2tokens = preprocessor.token_id_converter.ids2tokens
+ for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+ pre_word = "<s>"
+ cur_word = None
+ sent_lst = ids2tokens(sent_ids) + ['</s>']
+ ppl_out = " ".join(sent_lst) + "\n"
+ for word, word_nll in zip(sent_lst, sent_nll):
+ cur_word = word
+ word_nll = -word_nll.cpu()
+ if log_base is None:
+ word_prob = np.exp(word_nll)
+ else:
+ word_prob = log_base ** (word_nll / np.log(log_base))
+ ppl_out += ' p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+ cur=cur_word,
+ pre=pre_word,
+ prob=round(word_prob.item(), 8),
+ word_nll=round(word_nll.item(), 8)
+ )
+ pre_word = cur_word
+
+ sent_nll_mean = sent_nll.mean().cpu().numpy()
+ sent_nll_sum = sent_nll.sum().cpu().numpy()
+ if log_base is None:
+ sent_ppl = np.exp(sent_nll_mean)
+ else:
+ sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+ ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+ sent_nll=round(-sent_nll_sum.item(), 4),
+ sent_ppl=round(sent_ppl.item(), 4)
+ )
+ ppl_out_batch += ppl_out
+ utt2nll = round(-sent_nll_sum.item(), 5)
+ item = {'key': key, 'value': ppl_out}
+ if writer is not None:
+ writer["ppl"][key + ":\n"] = ppl_out
+ writer["utt2nll"][key] = str(utt2nll)
+ results.append(item)
+
+ ppl_out_all += ppl_out_batch
+
+ assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+ # nll: (B, L) -> (B,)
+ nll = nll.detach().cpu().numpy().sum(1)
+ # lengths: (B,)
+ lengths = lengths.detach().cpu().numpy()
+ total_nll += nll.sum()
+ total_ntokens += lengths.sum()
+
+ if log_base is None:
+ ppl = np.exp(total_nll / total_ntokens)
+ else:
+ ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+ avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+ total_nll=round(-total_nll.item(), 4),
+ total_ppl=round(ppl.item(), 4)
+ )
+ item = {'key': 'AVG PPL', 'value': avg_ppl}
+ ppl_out_all += avg_ppl
+ if writer is not None:
+ writer["ppl"]["AVG PPL : "] = avg_ppl
+ results.append(item)
+
+ return results
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "transformer":
+ return inference_lm(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Calc perplexity",
@@ -89,14 +376,6 @@
group.add_argument("--model_file", type=str)
group.add_argument("--mode", type=str, default="lm")
return parser
-
-def inference_launch(mode, **kwargs):
- if mode == "transformer":
- from funasr.bin.lm_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index 8641465..22b5f9c 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/modelscope_infer.py b/funasr/bin/modelscope_infer.py
deleted file mode 100755
index bc24340..0000000
--- a/funasr/bin/modelscope_infer.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import os
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(
- description="decoding configs",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument("--model_name",
- type=str,
- default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- help="model name in modelscope")
- parser.add_argument("--model_revision",
- type=str,
- default="v1.0.4",
- help="model revision in modelscope")
- parser.add_argument("--local_model_path",
- type=str,
- default=None,
- help="local model path, usually for fine-tuning")
- parser.add_argument("--wav_list",
- type=str,
- help="input wav list")
- parser.add_argument("--output_file",
- type=str,
- help="saving decoding results")
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- args = parser.parse_args()
-
- # set logging messages
- logging.basicConfig(
- level=logging.INFO,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("Decoding args: {}".format(args))
-
- # gpu setting
- if args.ngpu > 0:
- jobid = int(args.output_file.split(".")[-1])
- gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
-
- if args.local_model_path is None:
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model="damo/{}".format(args.model_name),
- model_revision=args.model_revision)
- else:
- inference_pipeline = pipeline(
- task=Tasks.auto_speech_recognition,
- model=args.local_model_path)
-
-
- with open(args.wav_list, 'r') as f_wav:
- wav_lines = f_wav.readlines()
-
- with open(args.output_file, "w") as f_out:
- for line in wav_lines:
- wav_id, wav_path = line.strip().split()
- logging.info("decoding, utt_id: ['{}']".format(wav_id))
- rec_result = inference_pipeline(audio_in=wav_path)
- if 'text' in rec_result:
- text = rec_result["text"]
- else:
- text = ''
- f_out.write(wav_id + " " + text + "\n")
- logging.info("best hypo: {} \n".format(text))
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
new file mode 100644
index 0000000..4b6cd27
--- /dev/null
+++ b/funasr/bin/punc_infer.py
@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+from pathlib import Path
+import sys
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.punctuation import PunctuationTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.datasets.preprocessor import split_to_mini_sentence
+
+
+class Text2Punc:
+
+ def __init__(
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
+ ):
+ # Build Model
+ model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ self.device = device
+ # Wrap the model so that model.inference() is exposed as forward()
+ self.wrapped_model = ForwardAdaptor(model, "inference")
+ self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ # logging.info(f"Model:\n{model}")
+ self.punc_list = train_args.punc_list
+ self.period = 0
+ for i in range(len(self.punc_list)):
+ if self.punc_list[i] == ",":
+ self.punc_list[i] = "，"
+ elif self.punc_list[i] == "?":
+ self.punc_list[i] = "？"
+ elif self.punc_list[i] == "。":
+ self.period = i
+ self.preprocessor = CodeMixTokenizerCommonPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ )
+
+ @torch.no_grad()
+ def __call__(self, text: Union[list, str], split_size=20):
+ data = {"text": text}
+ result = self.preprocessor(data=data, uid="12938712838719")
+ split_text = self.preprocessor.pop_split_text_data(result)
+ mini_sentences = split_to_mini_sentence(split_text, split_size)
+ mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+ assert len(mini_sentences) == len(mini_sentences_id)
+ cache_sent = []
+ cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+ new_mini_sentence = ""
+ new_mini_sentence_punc = []
+ cache_pop_trigger_limit = 200
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ data = {
+ "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+ "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+ }
+ data = to_device(data, self.device)
+ y, _ = self.wrapped_model(**data)
+ _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+ punctuations = indices
+ if indices.size()[0] != 1:
+ punctuations = torch.squeeze(indices)
+ assert punctuations.size()[0] == len(mini_sentence)
+
+ # Search for the last Period/QuestionMark as cache
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ # The sentence is too long; cut it off at a comma.
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.period
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+
+ # if len(punctuations) == 0:
+ # continue
+
+ punctuations_np = punctuations.cpu().numpy()
+ new_mini_sentence_punc += [int(x) for x in punctuations_np]
+ words_with_punc = []
+ for i in range(len(mini_sentence)):
+ if i > 0:
+ if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
+ words_with_punc.append(mini_sentence[i])
+ if self.punc_list[punctuations[i]] != "_":
+ words_with_punc.append(self.punc_list[punctuations[i]])
+ new_mini_sentence += "".join(words_with_punc)
+ # Add a period at the end of the sentence
+ new_mini_sentence_out = new_mini_sentence
+ new_mini_sentence_punc_out = new_mini_sentence_punc
+ if mini_sentence_i == len(mini_sentences) - 1:
+ if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？":
+ new_mini_sentence_out = new_mini_sentence + "。"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ return new_mini_sentence_out, new_mini_sentence_punc_out
+
+
+class Text2PuncVADRealtime:
+
+ def __init__(
+ self,
+ train_config: Optional[str],
+ model_file: Optional[str],
+ device: str = "cpu",
+ dtype: str = "float32",
+ ):
+ # Build Model
+ model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+ self.device = device
+ # Wrap the model so that model.inference() is exposed as forward()
+ self.wrapped_model = ForwardAdaptor(model, "inference")
+ self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+ # logging.info(f"Model:\n{model}")
+ self.punc_list = train_args.punc_list
+ self.period = 0
+ for i in range(len(self.punc_list)):
+ if self.punc_list[i] == ",":
+ self.punc_list[i] = "，"
+ elif self.punc_list[i] == "?":
+ self.punc_list[i] = "？"
+ elif self.punc_list[i] == "。":
+ self.period = i
+ self.preprocessor = CodeMixTokenizerCommonPreprocessor(
+ train=False,
+ token_type=train_args.token_type,
+ token_list=train_args.token_list,
+ bpemodel=train_args.bpemodel,
+ text_cleaner=train_args.cleaner,
+ g2p_type=train_args.g2p,
+ text_name="text",
+ non_linguistic_symbols=train_args.non_linguistic_symbols,
+ )
+
+ @torch.no_grad()
+ def __call__(self, text: Union[list, str], cache: list, split_size=20):
+ if cache is not None and len(cache) > 0:
+ precache = "".join(cache)
+ else:
+ precache = ""
+ cache = []
+ data = {"text": precache + " " + text}
+ result = self.preprocessor(data=data, uid="12938712838719")
+ split_text = self.preprocessor.pop_split_text_data(result)
+ mini_sentences = split_to_mini_sentence(split_text, split_size)
+ mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+ assert len(mini_sentences) == len(mini_sentences_id)
+ cache_sent = []
+ cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+ sentence_punc_list = []
+ sentence_words_list = []
+ cache_pop_trigger_limit = 200
+ skip_num = 0
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ data = {
+ "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+ "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+ "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
+ }
+ data = to_device(data, self.device)
+ y, _ = self.wrapped_model(**data)
+ _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+ punctuations = indices
+ if indices.size()[0] != 1:
+ punctuations = torch.squeeze(indices)
+ assert punctuations.size()[0] == len(mini_sentence)
+
+ # Search for the last Period/QuestionMark as cache
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ # The sentence is too long; cut it off at a comma.
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.period
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+
+ punctuations_np = punctuations.cpu().numpy()
+ sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+ sentence_words_list += mini_sentence
+
+ assert len(sentence_punc_list) == len(sentence_words_list)
+ words_with_punc = []
+ sentence_punc_list_out = []
+ for i in range(0, len(sentence_words_list)):
+ if i > 0:
+ if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+ sentence_words_list[i] = " " + sentence_words_list[i]
+ if skip_num < len(cache):
+ skip_num += 1
+ else:
+ words_with_punc.append(sentence_words_list[i])
+ if skip_num >= len(cache):
+ sentence_punc_list_out.append(sentence_punc_list[i])
+ if sentence_punc_list[i] != "_":
+ words_with_punc.append(sentence_punc_list[i])
+ sentence_out = "".join(words_with_punc)
+
+ sentenceEnd = -1
+ for i in range(len(sentence_punc_list) - 2, 1, -1):
+ if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "？":
+ sentenceEnd = i
+ break
+ cache_out = sentence_words_list[sentenceEnd + 1:]
+ if sentence_out[-1] in self.punc_list:
+ sentence_out = sentence_out[:-1]
+ sentence_punc_list_out[-1] = "_"
+ return sentence_out, sentence_punc_list_out, cache_out
+
+
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index b1d9235..7f60f81 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
-
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -14,6 +16,176 @@
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
+from pathlib import Path
+import sys
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.punctuation import PunctuationTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2triple_str
+from funasr.datasets.preprocessor import split_to_mini_sentence
+from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
+
+def inference_punc(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+ text2punc = Text2Punc(train_config, model_file, device)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ split_size = 20
+
+ if raw_inputs is not None:
+ line = raw_inputs.strip()
+ key = "demo"
+ if line == "":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ result, _ = text2punc(line)
+ item = {'key': key, 'value': result}
+ results.append(item)
+ return results
+
+ for inference_text, _, _ in data_path_and_name_and_type:
+ with open(inference_text, "r", encoding="utf-8") as fin:
+ for line in fin:
+ line = line.strip()
+ segs = line.split("\t")
+ if len(segs) != 2:
+ continue
+ key = segs[0]
+ if len(segs[1]) == 0:
+ continue
+ result, _ = text2punc(segs[1])
+ item = {'key': key, 'value': result}
+ results.append(item)
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ output_file_name = "infer.out"
+ Path(output_path).mkdir(parents=True, exist_ok=True)
+ output_file_path = (Path(output_path) / output_file_name).absolute()
+ with open(output_file_path, "w", encoding="utf-8") as fout:
+ for item_i in results:
+ key_out = item_i["key"]
+ value_out = item_i["value"]
+ fout.write(f"{key_out}\t{value_out}\n")
+ return results
+
+ return _forward
+
+def inference_punc_vad_realtime(
+ batch_size: int,
+ dtype: str,
+ ngpu: int,
+ seed: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ #cache: list,
+ key_file: Optional[str],
+ train_config: Optional[str],
+ model_file: Optional[str],
+ output_dir: Optional[str] = None,
+ param_dict: dict = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+ text2punc = Text2PuncVADRealtime(train_config, model_file, device)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[List[Any], bytes, str] = None,
+ output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
+ ):
+ results = []
+ split_size = 10
+ cache_in = param_dict["cache"]
+ if raw_inputs is not None:
+ line = raw_inputs.strip()
+ key = "demo"
+ if line == "":
+ item = {'key': key, 'value': ""}
+ results.append(item)
+ return results
+ result, _, cache = text2punc(line, cache_in)
+ param_dict["cache"] = cache
+ item = {'key': key, 'value': result}
+ results.append(item)
+ return results
+
+ return results
+
+ return _forward
+
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "punc":
+ return inference_punc(**kwargs)
+ if mode == "punc_VadRealtime":
+ return inference_punc_vad_realtime(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -70,18 +242,6 @@
return parser
-def inference_launch(mode, **kwargs):
- if mode == "punc":
- from funasr.bin.punctuation_infer import inference_modelscope
- return inference_modelscope(**kwargs)
- if mode == "punc_VadRealtime":
- from funasr.bin.punctuation_infer_vadrealtime import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
@@ -105,7 +265,9 @@
kwargs.pop("gpuid_list", None)
kwargs.pop("njob", None)
- results = inference_launch(**kwargs)
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
if __name__ == "__main__":
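With this change the launcher becomes a factory: inference_launch builds the model once and returns the _forward closure, which the caller can invoke repeatedly; main now calls the returned pipeline instead of receiving results directly. A hypothetical call sequence under that assumption (all argument values are placeholders):

    pipeline = inference_launch(
        mode="punc",
        batch_size=1, dtype="float32", ngpu=0, seed=0, num_workers=1,
        log_level="INFO", key_file=None,
        train_config="punc_config.yaml", model_file="punc_model.pb",
        output_dir="exp/punc_out",
    )
    # Passing raw_inputs skips the file loop and punctuates a single string.
    results = pipeline(None, raw_inputs="大家好今天我们讨论标点预测")
    print(results[0]["value"])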
diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
index 61b63ec..aeded7b 100644
--- a/funasr/bin/punc_train.py
+++ b/funasr/bin/punc_train.py
@@ -1,4 +1,8 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import os
from funasr.tasks.punctuation import PunctuationTask
diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
deleted file mode 100644
index 077814d..0000000
--- a/funasr/bin/punctuation_infer.py
+++ /dev/null
@@ -1,320 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
-
-
-class Text2Punc:
-
- def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
- ):
- # Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
- self.device = device
- # Wrap the model to make model.nll() data-parallel
- self.wrapped_model = ForwardAdaptor(model, "inference")
- self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- # logging.info(f"Model:\n{model}")
- self.punc_list = train_args.punc_list
- self.period = 0
- for i in range(len(self.punc_list)):
- if self.punc_list[i] == ",":
- self.punc_list[i] = "，"
- elif self.punc_list[i] == "?":
- self.punc_list[i] = "？"
- elif self.punc_list[i] == "。":
- self.period = i
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- )
-
- @torch.no_grad()
- def __call__(self, text: Union[list, str], split_size=20):
- data = {"text": text}
- result = self.preprocessor(data=data, uid="12938712838719")
- split_text = self.preprocessor.pop_split_text_data(result)
- mini_sentences = split_to_mini_sentence(split_text, split_size)
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
- assert len(mini_sentences) == len(mini_sentences_id)
- cache_sent = []
- cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
- new_mini_sentence = ""
- new_mini_sentence_punc = []
- cache_pop_trigger_limit = 200
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence_id = mini_sentences_id[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
- data = {
- "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
- "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- }
- data = to_device(data, self.device)
- y, _ = self.wrapped_model(**data)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences) - 1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
- sentenceEnd = i
- break
- if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
- last_comma_index = i
-
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence is too long; cut it off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
-
- # if len(punctuations) == 0:
- # continue
-
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
- words_with_punc = []
- for i in range(len(mini_sentence)):
- if i > 0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
- mini_sentence[i] = " " + mini_sentence[i]
- words_with_punc.append(mini_sentence[i])
- if self.punc_list[punctuations[i]] != "_":
- words_with_punc.append(self.punc_list[punctuations[i]])
- new_mini_sentence += "".join(words_with_punc)
- # Add Period for the end of the sentence
- new_mini_sentence_out = new_mini_sentence
- new_mini_sentence_punc_out = new_mini_sentence_punc
- if mini_sentence_i == len(mini_sentences) - 1:
- if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
- new_mini_sentence_out = new_mini_sentence[:-1] + "。"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？":
- new_mini_sentence_out = new_mini_sentence + "銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- return new_mini_sentence_out, new_mini_sentence_punc_out
-
-
-def inference(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- output_dir: str,
- log_level: Union[int, str],
- train_config: Optional[str],
- model_file: Optional[str],
- key_file: Optional[str] = None,
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[List[Any], bytes, str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- output_dir=output_dir,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- train_config=train_config,
- model_file=model_file,
- param_dict=param_dict,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
- text2punc = Text2Punc(train_config, model_file, device)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- ):
- results = []
- split_size = 20
-
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "demo"
- if line == "":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- result, _ = text2punc(line)
- item = {'key': key, 'value': result}
- results.append(item)
- return results
-
- for inference_text, _, _ in data_path_and_name_and_type:
- with open(inference_text, "r", encoding="utf-8") as fin:
- for line in fin:
- line = line.strip()
- segs = line.split("\t")
- if len(segs) != 2:
- continue
- key = segs[0]
- if len(segs[1]) == 0:
- continue
- result, _ = text2punc(segs[1])
- item = {'key': key, 'value': result}
- results.append(item)
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path != None:
- output_file_name = "infer.out"
- Path(output_path).mkdir(parents=True, exist_ok=True)
- output_file_path = (Path(output_path) / output_file_name).absolute()
- with open(output_file_path, "w", encoding="utf-8") as fout:
- for item_i in results:
- key_out = item_i["key"]
- value_out = item_i["value"]
- fout.write(f"{key_out}\t{value_out}\n")
- return results
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Punctuation inference",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
- group.add_argument("--raw_inputs", type=str, required=False)
- group.add_argument("--cache", type=list, required=False)
- group.add_argument("--param_dict", type=dict, required=False)
- group.add_argument("--key_file", type=str_or_none)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- # kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
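Both the deleted module and its replacement rely on split_to_mini_sentence to chunk the token stream before each forward pass. A behavioral sketch of what that helper is assumed to do (an illustration, not the FunASR source itself):

    def split_to_mini_sentence(words, word_limit=20):
        # Chunk a token list into pieces of at most word_limit tokens.
        assert word_limit > 1
        return [words[i:i + word_limit] for i in range(0, len(words), word_limit)]

    split_to_mini_sentence(list("abcdefg"), 3)  # [['a','b','c'], ['d','e','f'], ['g']]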
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
deleted file mode 100644
index 0dc01f5..0000000
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ /dev/null
@@ -1,311 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-from pathlib import Path
-import sys
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Any
-from typing import List
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-
-from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.punctuation import PunctuationTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.datasets.preprocessor import split_to_mini_sentence
-
-
-class Text2Punc:
-
- def __init__(
- self,
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
- ):
- # Build Model
- model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
- self.device = device
- # Wrap the model to make model.nll() data-parallel
- self.wrapped_model = ForwardAdaptor(model, "inference")
- self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- # logging.info(f"Model:\n{model}")
- self.punc_list = train_args.punc_list
- self.period = 0
- for i in range(len(self.punc_list)):
- if self.punc_list[i] == ",":
- self.punc_list[i] = "，"
- elif self.punc_list[i] == "?":
- self.punc_list[i] = "？"
- elif self.punc_list[i] == "。":
- self.period = i
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
- train=False,
- token_type=train_args.token_type,
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- )
-
-
- @torch.no_grad()
- def __call__(self, text: Union[list, str], cache: list, split_size=20):
- if cache is not None and len(cache) > 0:
- precache = "".join(cache)
- else:
- precache = ""
- cache = []
- data = {"text": precache + " " + text}
- result = self.preprocessor(data=data, uid="12938712838719")
- split_text = self.preprocessor.pop_split_text_data(result)
- mini_sentences = split_to_mini_sentence(split_text, split_size)
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
- assert len(mini_sentences) == len(mini_sentences_id)
- cache_sent = []
- cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
- sentence_punc_list = []
- sentence_words_list = []
- cache_pop_trigger_limit = 200
- skip_num = 0
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence_id = mini_sentences_id[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
- data = {
- "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
- "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
- }
- data = to_device(data, self.device)
- y, _ = self.wrapped_model(**data)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences) - 1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "？":
- sentenceEnd = i
- break
- if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
- last_comma_index = i
-
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence is too long; cut it off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
-
- punctuations_np = punctuations.cpu().numpy()
- sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
- sentence_words_list += mini_sentence
-
- assert len(sentence_punc_list) == len(sentence_words_list)
- words_with_punc = []
- sentence_punc_list_out = []
- for i in range(0, len(sentence_words_list)):
- if i > 0:
- if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
- sentence_words_list[i] = " " + sentence_words_list[i]
- if skip_num < len(cache):
- skip_num += 1
- else:
- words_with_punc.append(sentence_words_list[i])
- if skip_num >= len(cache):
- sentence_punc_list_out.append(sentence_punc_list[i])
- if sentence_punc_list[i] != "_":
- words_with_punc.append(sentence_punc_list[i])
- sentence_out = "".join(words_with_punc)
-
- sentenceEnd = -1
- for i in range(len(sentence_punc_list) - 2, 1, -1):
- if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "？":
- sentenceEnd = i
- break
- cache_out = sentence_words_list[sentenceEnd + 1 :]
- if sentence_out[-1] in self.punc_list:
- sentence_out = sentence_out[:-1]
- sentence_punc_list_out[-1] = "_"
- return sentence_out, sentence_punc_list_out, cache_out
-
-
-def inference(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- output_dir: str,
- log_level: Union[int, str],
- train_config: Optional[str],
- model_file: Optional[str],
- key_file: Optional[str] = None,
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[List[Any], bytes, str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- output_dir=output_dir,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- train_config=train_config,
- model_file=model_file,
- param_dict=param_dict,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs, cache)
-
-
-def inference_modelscope(
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- #cache: list,
- key_file: Optional[str],
- train_config: Optional[str],
- model_file: Optional[str],
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
- text2punc = Text2Punc(train_config, model_file, device)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[List[Any], bytes, str] = None,
- output_dir_v2: Optional[str] = None,
- cache: List[Any] = None,
- param_dict: dict = None,
- ):
- results = []
- split_size = 10
- cache_in = param_dict["cache"]
- if raw_inputs != None:
- line = raw_inputs.strip()
- key = "demo"
- if line == "":
- item = {'key': key, 'value': ""}
- results.append(item)
- return results
- result, _, cache = text2punc(line, cache_in)
- param_dict["cache"] = cache
- item = {'key': key, 'value': result}
- results.append(item)
- return results
-
- return results
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Punctuation inference",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
- group.add_argument("--raw_inputs", type=str, required=False)
- group.add_argument("--cache", type=list, required=False)
- group.add_argument("--param_dict", type=dict, required=False)
- group.add_argument("--key_file", type=str_or_none)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument("--train_config", type=str)
- group.add_argument("--model_file", type=str)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- # kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
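The cache bookkeeping in the realtime __call__ above is easy to misread: cached words from the previous call are prepended so the model sees full context, then skipped on output so they are not emitted twice. A toy rendering of that skip logic (an illustration, not FunASR code):

    cache = ["今天", "天气"]        # words re-fed from the previous call
    words = cache + ["很", "好"]    # the model scores cache + new words together
    puncs = ["_", "，", "_", "。"]  # "_" marks "no punctuation after this word"
    out, skip = [], 0
    for w, p in zip(words, puncs):
        if skip < len(cache):
            skip += 1              # already emitted by the earlier call
            continue
        out.append(w)
        if p != "_":
            out.append(p)
    print("".join(out))            # -> 很好。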
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
deleted file mode 100644
index c894f54..0000000
--- a/funasr/bin/sa_asr_inference.py
+++ /dev/null
@@ -1,687 +0,0 @@
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
-from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
-from funasr.modules.beam_search.beam_search_sa_asr import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.sa_asr import ASRTask
-from funasr.tasks.lm import LMTask
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tasks.asr import frontend_choices
-
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-
-class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- batch_size: int = 1,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- streaming: bool = False,
- frontend_conf: dict = None,
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend=='wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- else:
- frontend_class=frontend_choices.get_class(asr_train_args.frontend)
- frontend = frontend_class(**asr_train_args.frontend_conf).eval()
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
-
- decoder = asr_model.decoder
-
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- token_list = asr_model.token_list
- scorers.update(
- decoder=decoder,
- ctc=ctc,
- length_bonus=LengthBonus(len(token_list)),
- )
-
- # 2. Build Language model
- if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
- )
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- # 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
-
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
- )
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
- self.beam_search = beam_search
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray], profile: Union[torch.Tensor, np.ndarray], profile_lengths: Union[torch.Tensor, np.ndarray]
- ) -> List[
- Tuple[
- Optional[str],
- Optional[str],
- List[str],
- List[int],
- Union[Hypothesis],
- ]
- ]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, text_id, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if isinstance(profile, np.ndarray):
- profile = torch.tensor(profile)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- asr_enc, _, spk_enc = self.asr_model.encode(**batch)
- if isinstance(asr_enc, tuple):
- asr_enc = asr_enc[0]
- if isinstance(spk_enc, tuple):
- spk_enc = spk_enc[0]
- assert len(asr_enc) == 1, len(asr_enc)
- assert len(spk_enc) == 1, len(spk_enc)
-
- # c. Pass the encoder result to the beam search
- nbest_hyps = self.beam_search(
- asr_enc[0], spk_enc[0], profile[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1: last_pos]
- else:
- token_int = hyp.yseq[1: last_pos].tolist()
-
- spk_weigths = torch.stack(hyp.spk_weigths, dim=0)
-
- token_ori = self.converter.ids2tokens(token_int)
- text_ori = self.tokenizer.tokens2text(token_ori)
-
- text_ori_spklist = text_ori.split('$')
- cur_index = 0
- spk_choose = []
- for i in range(len(text_ori_spklist)):
- text_ori_split = text_ori_spklist[i]
- n = len(text_ori_split)
- spk_weights_local = spk_weigths[cur_index: cur_index + n]
- cur_index = cur_index + n + 1
- spk_weights_local = spk_weights_local.mean(dim=0)
- spk_choose_local = spk_weights_local.argmax(-1)
- spk_choose.append(spk_choose_local.item() + 1)
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- text_spklist = text.split('$')
- assert len(spk_choose) == len(text_spklist)
-
- spk_list=[]
- for i in range(len(text_spklist)):
- text_split = text_spklist[i]
- n = len(text_split)
- spk_list.append(str(spk_choose[i]) * n)
-
- text_id = '$'.join(spk_list)
-
- assert len(text) == len(text_id)
-
- results.append((text, text_id, token, token_int, hyp))
-
- assert check_return_type(results)
- return results
-
-def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
- mc=mc,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- mc: bool = False,
- param_dict: dict = None,
- **kwargs,
-):
- assert check_argument_types()
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- for handler in logging.root.handlers[:]:
- logging.root.removeHandler(handler)
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2text
- speech2text_kwargs = dict(
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- bpemodel=bpemodel,
- device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- dtype=dtype,
- beam_size=beam_size,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
- nbest=nbest,
- streaming=streaming,
- )
- logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
- speech2text = Speech2Text(**speech2text_kwargs)
-
- def _forward(data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- mc=mc,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7. Start for-loop
- # FIXME(kamo): The output format should be discussed
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- # N-best list of (text, token, token_int, hyp_object)
- try:
- results = speech2text(**batch)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp]] * nbest
-
- # Only supporting batch_size==1
- key = keys[0]
- for n, (text, text_id, token, token_int, hyp) in zip(range(1, nbest + 1), results):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
-
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["text_id"][key] = text_id
-
- if text is not None:
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = text
-
- logging.info("uttid: {}".format(key))
- logging.info("text predictions: {}".format(text))
- logging.info("text_id predictions: {}\n".format(text_id))
- return asr_result_list
-
- return _forward
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
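The deleted Speech2Text above assigns a speaker to each '$'-delimited span of the hypothesis by averaging the per-token speaker attention weights over the span and taking the argmax, stepping one extra position to skip the separator itself. A toy rendering of that step (the weights are made up for illustration):

    import torch

    spk_weights = torch.tensor([[0.9, 0.1],   # token "a"
                                [0.8, 0.2],   # token "b"
                                [0.5, 0.5],   # the "$" separator token
                                [0.2, 0.8]])  # token "c"
    text = "ab$c"
    spk_choose, cur = [], 0
    for seg in text.split("$"):
        local = spk_weights[cur:cur + len(seg)].mean(dim=0)
        spk_choose.append(int(local.argmax(-1)) + 1)  # 1-based speaker id
        cur += len(seg) + 1                           # step over the "$"
    print(spk_choose)                                 # -> [1, 2]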
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
index 07b9b19..67106cf 100755
--- a/funasr/bin/sa_asr_train.py
+++ b/funasr/bin/sa_asr_train.py
@@ -1,4 +1,7 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os
diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py
deleted file mode 100755
index c55bc35..0000000
--- a/funasr/bin/sond_inference.py
+++ /dev/null
@@ -1,577 +0,0 @@
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-from collections import OrderedDict
-import numpy as np
-import soundfile
-import torch
-from torch.nn import functional as F
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.diar import DiarTask
-from funasr.tasks.asr import ASRTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from scipy.ndimage import median_filter
-from funasr.utils.misc import statistic_model_parameters
-from funasr.datasets.iterable_dataset import load_bytes
-
-
-class Speech2Diarization:
- """Speech2Xvector class
-
- Examples:
- >>> import soundfile
- >>> import numpy as np
- >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
- >>> profile = np.load("profiles.npy")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2diar(audio, profile)
- {"spk1": [(int, int), ...], ...}
-
- """
-
- def __init__(
- self,
- diar_train_config: Union[Path, str] = None,
- diar_model_file: Union[Path, str] = None,
- device: Union[str, torch.device] = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: float = 10,
- ):
- assert check_argument_types()
-
- # TODO: 1. Build Diarization model
- diar_model, diar_train_args = DiarTask.build_model_from_file(
- config_file=diar_train_config,
- model_file=diar_model_file,
- device=device
- )
- logging.info("diar_model: {}".format(diar_model))
- logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
- logging.info("diar_train_args: {}".format(diar_train_args))
- diar_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.diar_model = diar_model
- self.diar_train_args = diar_train_args
- self.token_list = diar_train_args.token_list
- self.smooth_size = smooth_size
- self.dur_threshold = dur_threshold
- self.device = device
- self.dtype = dtype
-
- def smooth_multi_labels(self, multi_label):
- multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
- return multi_label
-
- @staticmethod
- def calc_spk_turns(label_arr, spk_list):
- turn_list = []
- length = label_arr.shape[0]
- n_spk = label_arr.shape[1]
- for k in range(n_spk):
- if spk_list[k] == "None":
- continue
- in_utt = False
- start = 0
- for i in range(length):
- if label_arr[i, k] == 1 and in_utt is False:
- start = i
- in_utt = True
- if label_arr[i, k] == 0 and in_utt is True:
- turn_list.append([spk_list[k], start, i - start])
- in_utt = False
- if in_utt:
- turn_list.append([spk_list[k], start, length - start])
- return turn_list
-
- @staticmethod
- def seq2arr(seq, vec_dim=8):
- def int2vec(x, vec_dim=8, dtype=np.int):
- b = ('{:0' + str(vec_dim) + 'b}').format(x)
- # little-endian order: lower bit first
- return (np.array(list(b)[::-1]) == '1').astype(dtype)
-
- # process oov
- seq = np.array([int(x) for x in seq])
- new_seq = []
- for i, x in enumerate(seq):
- if x < 2 ** vec_dim:
- new_seq.append(x)
- else:
- idx_list = np.where(seq < 2 ** vec_dim)[0]
- idx = np.abs(idx_list - i).argmin()
- new_seq.append(seq[idx_list[idx]])
- return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
-
- def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
- logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T
- # upsampling outputs to match inputs
- ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
- logits_idx = F.upsample(
- logits_idx.unsqueeze(1).float(),
- size=(ut, ),
- mode="nearest",
- ).squeeze(1).long()
- logits_idx = logits_idx[0].tolist()
- pse_labels = [self.token_list[x] for x in logits_idx]
- if output_format == "pse_labels":
- return pse_labels, None
-
- multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers
- multi_labels = self.smooth_multi_labels(multi_labels)
- if output_format == "binary_labels":
- return multi_labels, None
-
- spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
- spk_turns = self.calc_spk_turns(multi_labels, spk_list)
- results = OrderedDict()
- for spk, st, dur in spk_turns:
- if spk not in results:
- results[spk] = []
- if dur > self.dur_threshold:
- results[spk].append((st, st+dur))
-
- # sort segments in start time ascending
- for spk in results:
- results[spk] = sorted(results[spk], key=lambda x: x[0])
-
- return results, pse_labels
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- profile: Union[torch.Tensor, np.ndarray],
- output_format: str = "speaker_turn"
- ):
- """Inference
-
- Args:
- speech: Input speech data
- profile: Speaker profiles
- Returns:
- diarization results for each speaker
-
- """
- assert check_argument_types()
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if isinstance(profile, np.ndarray):
- profile = torch.tensor(profile)
-
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
- speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
- batch = {"speech": speech, "speech_lengths": speech_lengths,
- "profile": profile, "profile_lengths": profile_lengths}
- # a. To device
- batch = to_device(batch, device=self.device)
-
- logits = self.diar_model.prediction_forward(**batch)
- results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
-
- return results, pse_labels
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Xvector instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Diarization: Speech2Diarization instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2Diarization(**kwargs)
-
-
-def inference_modelscope(
- diar_train_config: str,
- diar_model_file: str,
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 0,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: int = 10,
- out_format: str = "vad",
- param_dict: Optional[dict] = None,
- mode: str = "sond",
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2a. Build speech2xvec [Optional]
- if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]:
- assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict."
- assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict."
- sv_train_config = param_dict["sv_train_config"]
- sv_model_file = param_dict["sv_model_file"]
- if "model_dir" in param_dict:
- sv_train_config = os.path.join(param_dict["model_dir"], sv_train_config)
- sv_model_file = os.path.join(param_dict["model_dir"], sv_model_file)
- from funasr.bin.sv_inference import Speech2Xvector
- speech2xvector_kwargs = dict(
- sv_train_config=sv_train_config,
- sv_model_file=sv_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- embedding_node="resnet1_dense"
- )
- logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector.from_pretrained(
- model_tag=model_tag,
- **speech2xvector_kwargs,
- )
- speech2xvector.sv_model.eval()
-
- # 2b. Build speech2diar
- speech2diar_kwargs = dict(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- smooth_size=smooth_size,
- dur_threshold=dur_threshold,
- )
- logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2Diarization.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
- speech2diar.diar_model.eval()
-
- def output_results_str(results: dict, uttid: str):
- rst = []
- mid = uttid.rsplit("-", 1)[0]
- for key in results:
- results[key] = [(x[0]/100, x[1]/100) for x in results[key]]
- if out_format == "vad":
- for spk, segs in results.items():
- rst.append("{} {}".format(spk, segs))
- else:
- template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
- for spk, segs in results.items():
- rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
-
- return "\n".join(rst)
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- logging.info("param_dict: {}".format(param_dict))
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, (list, tuple)):
- if not isinstance(raw_inputs[0], List):
- raw_inputs = [raw_inputs]
-
- assert all([len(example) >= 2 for example in raw_inputs]), \
- "The length of test case in raw_inputs must larger than 1 (>=2)."
-
- def prepare_dataset():
- for idx, example in enumerate(raw_inputs):
- # read waveform file
- example = [load_bytes(x) if isinstance(x, bytes) else x
- for x in example]
- example = [soundfile.read(x)[0] if isinstance(x, str) else x
- for x in example]
- # convert torch tensor to numpy array
- example = [x.numpy() if isinstance(x, torch.Tensor) else x
- for x in example]
- speech = example[0]
- logging.info("Extracting profiles for {} waveforms".format(len(example)-1))
- profile = [speech2xvector.calculate_embedding(x) for x in example[1:]]
- profile = torch.cat(profile, dim=0)
- yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]}
-
- loader = prepare_dataset()
- else:
- raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
- else:
- # 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 7. Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- output_writer = open("{}/result.txt".format(output_path), "w")
- pse_label_writer = open("{}/labels.txt".format(output_path), "w")
- logging.info("Start to diarize...")
- result_list = []
- for idx, (keys, batch) in enumerate(loader):
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- results, pse_labels = speech2diar(**batch)
- # Only supporting batch_size==1
- key, value = keys[0], output_results_str(results, keys[0])
- item = {"key": key, "value": value}
- result_list.append(item)
- if output_path is not None:
- output_writer.write(value)
- output_writer.flush()
- pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels)))
- pse_label_writer.flush()
-
- if idx % 100 == 0:
- logging.info("Processing {:5d}: {}".format(idx, key))
-
- if output_path is not None:
- output_writer.close()
- pse_label_writer.close()
-
- return result_list
-
- return _forward
-
-
-def inference(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- diar_train_config: Optional[str],
- diar_model_file: Optional[str],
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 0,
- seed: int = 0,
- num_workers: int = 1,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- smooth_size: int = 83,
- dur_threshold: int = 10,
- out_format: str = "vad",
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- diar_train_config=diar_train_config,
- diar_model_file=diar_model_file,
- output_dir=output_dir,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- model_tag=model_tag,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- smooth_size=smooth_size,
- dur_threshold=dur_threshold,
- out_format=out_format,
- **kwargs,
- )
-
- return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speaker verification/x-vector extraction",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--diar_train_config",
- type=str,
- help="diarization training configuration",
- )
- group.add_argument(
- "--diar_model_file",
- type=str,
- help="diarization model parameter file",
- )
- group.add_argument(
- "--dur_threshold",
- type=int,
- default=10,
- help="The threshold for short segments in number frames"
- )
- parser.add_argument(
- "--smooth_size",
- type=int,
- default=83,
- help="The smoothing window length in number frames"
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument("--streaming", type=str2bool, default=False)
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- logging.info("args: {}".format(kwargs))
- if args.output_dir is None:
- jobid, n_gpu = 1, 1
- gpuid = args.gpuid_list.split(",")[jobid-1]
- else:
- jobid = int(args.output_dir.split(".")[-1])
- n_gpu = len(args.gpuid_list.split(","))
- gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- results_list = inference(**kwargs)
- for results in results_list:
- print("{} {}".format(results["key"], results["value"]))
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
new file mode 100755
index 0000000..1517bfa
--- /dev/null
+++ b/funasr/bin/sv_infer.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from kaldiio import WriteHelper
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.sv import SVTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.misc import statistic_model_parameters
+
+class Speech2Xvector:
+ """Speech2Xvector class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2xvector(audio)
+        (embedding, ref_embedding, score)  # ref_embedding and score are None without ref_speech
+
+    """
+
+ def __init__(
+ self,
+ sv_train_config: Union[Path, str] = None,
+ sv_model_file: Union[Path, str] = None,
+ device: str = "cpu",
+ batch_size: int = 1,
+ dtype: str = "float32",
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ ):
+ assert check_argument_types()
+
+        # 1. Build SV model
+ sv_model, sv_train_args = SVTask.build_model_from_file(
+ config_file=sv_train_config,
+ model_file=sv_model_file,
+ device=device
+ )
+ logging.info("sv_model: {}".format(sv_model))
+ logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
+ logging.info("sv_train_args: {}".format(sv_train_args))
+ sv_model.to(dtype=getattr(torch, dtype)).eval()
+
+ self.sv_model = sv_model
+ self.sv_train_args = sv_train_args
+ self.device = device
+ self.dtype = dtype
+ self.embedding_node = embedding_node
+
+ @torch.no_grad()
+ def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ # data: (Nsamples,) -> (1, Nsamples)
+ speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ # lengths: (1,)
+ lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ batch = {"speech": speech, "speech_lengths": lengths}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, ilens = self.sv_model.encode(**batch)
+
+ # c. Forward Pooling
+ pooling = self.sv_model.pooling_layer(enc)
+
+ # d. Forward Decoder
+ outputs, embeddings = self.sv_model.decoder(pooling)
+
+ if self.embedding_node not in embeddings:
+ raise ValueError("Required embedding node {} not in {}".format(
+ self.embedding_node, embeddings.keys()))
+
+ return embeddings[self.embedding_node]
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray],
+ ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ ref_speech: Reference speech to compare
+ Returns:
+ embedding, ref_embedding, similarity_score
+
+ """
+ assert check_argument_types()
+ self.sv_model.eval()
+ embedding = self.calculate_embedding(speech)
+ ref_emb, score = None, None
+ if ref_speech is not None:
+ ref_emb = self.calculate_embedding(ref_speech)
+ score = torch.cosine_similarity(embedding, ref_emb)
+
+ results = (embedding, ref_emb, score)
+ assert check_return_type(results)
+ return results
+
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+ ):
+ """Build Speech2Xvector instance from the pretrained model.
+
+ Args:
+ model_tag (Optional[str]): Model tag of the pretrained models.
+ Currently, the tags of espnet_model_zoo are supported.
+
+ Returns:
+ Speech2Xvector: Speech2Xvector instance.
+
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2Xvector(**kwargs)
+
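+# Usage sketch (hedged): the config/model/wav paths below are hypothetical
+# placeholders, shown only to illustrate the verification flow.
+#
+#   speech2xvector = Speech2Xvector.from_pretrained(
+#       sv_train_config="sv.yaml", sv_model_file="sv.pb", device="cpu",
+#   )
+#   enroll, _ = soundfile.read("enroll.wav")
+#   test, _ = soundfile.read("test.wav")
+#   emb, ref_emb, score = speech2xvector(test, ref_speech=enroll)
+#   # score is a cosine similarity; ref_emb and score are None without ref_speech.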
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
deleted file mode 100755
index 76b1dfb..0000000
--- a/funasr/bin/sv_inference.py
+++ /dev/null
@@ -1,443 +0,0 @@
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from kaldiio import WriteHelper
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
-from funasr.tasks.asr import ASRTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils.misc import statistic_model_parameters
-
-class Speech2Xvector:
- """Speech2Xvector class
-
- Examples:
- >>> import soundfile
- >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2xvector(audio)
- [(text, token, token_int, hypothesis object), ...]
-
- """
-
- def __init__(
- self,
- sv_train_config: Union[Path, str] = None,
- sv_model_file: Union[Path, str] = None,
- device: str = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- ):
- assert check_argument_types()
-
- # TODO: 1. Build SV model
- sv_model, sv_train_args = SVTask.build_model_from_file(
- config_file=sv_train_config,
- model_file=sv_model_file,
- device=device
- )
- logging.info("sv_model: {}".format(sv_model))
- logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
- logging.info("sv_train_args: {}".format(sv_train_args))
- sv_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.sv_model = sv_model
- self.sv_train_args = sv_train_args
- self.device = device
- self.dtype = dtype
- self.embedding_node = embedding_node
-
- @torch.no_grad()
- def calculate_embedding(self, speech: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- # data: (Nsamples,) -> (1, Nsamples)
- speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- # lengths: (1,)
- lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
- batch = {"speech": speech, "speech_lengths": lengths}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, ilens = self.sv_model.encode(**batch)
-
- # c. Forward Pooling
- pooling = self.sv_model.pooling_layer(enc)
-
- # d. Forward Decoder
- outputs, embeddings = self.sv_model.decoder(pooling)
-
- if self.embedding_node not in embeddings:
- raise ValueError("Required embedding node {} not in {}".format(
- self.embedding_node, embeddings.keys()))
-
- return embeddings[self.embedding_node]
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray],
- ref_speech: Optional[Union[torch.Tensor, np.ndarray]] = None,
- ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], Union[torch.Tensor, None]]:
- """Inference
-
- Args:
- speech: Input speech data
- ref_speech: Reference speech to compare
- Returns:
- embedding, ref_embedding, similarity_score
-
- """
- assert check_argument_types()
- self.sv_model.eval()
- embedding = self.calculate_embedding(speech)
- ref_emb, score = None, None
- if ref_speech is not None:
- ref_emb = self.calculate_embedding(ref_speech)
- score = torch.cosine_similarity(embedding, ref_emb)
-
- results = (embedding, ref_emb, score)
- assert check_return_type(results)
- return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ):
- """Build Speech2Xvector instance from the pretrained model.
-
- Args:
- model_tag (Optional[str]): Model tag of the pretrained models.
- Currently, the tags of espnet_model_zoo are supported.
-
- Returns:
- Speech2Xvector: Speech2Xvector instance.
-
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2Xvector(**kwargs)
-
-
-def inference_modelscope(
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- sv_train_config: Optional[str] = "sv.yaml",
- sv_model_file: Optional[str] = "sv.pb",
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- param_dict: Optional[dict] = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
- logging.info("param_dict: {}".format(param_dict))
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2xvector
- speech2xvector_kwargs = dict(
- sv_train_config=sv_train_config,
- sv_model_file=sv_model_file,
- device=device,
- dtype=dtype,
- streaming=streaming,
- embedding_node=embedding_node
- )
- logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector.from_pretrained(
- model_tag=model_tag,
- **speech2xvector_kwargs,
- )
- speech2xvector.sv_model.eval()
-
- def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
- ):
- logging.info("param_dict: {}".format(param_dict))
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- # 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- # 7 .Start for-loop
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- embd_writer, ref_embd_writer, score_writer = None, None, None
- if output_path is not None:
- os.makedirs(output_path, exist_ok=True)
- embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
- sv_result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
- embedding, ref_embedding, score = speech2xvector(**batch)
- # Only supporting batch_size==1
- key = keys[0]
- normalized_score = 0.0
- if score is not None:
- score = score.item()
- normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
- item = {"key": key, "value": normalized_score}
- else:
- item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
- sv_result_list.append(item)
- if output_path is not None:
- embd_writer(key, embedding[0].cpu().numpy())
- if ref_embedding is not None:
- if ref_embd_writer is None:
- ref_embd_writer = WriteHelper(
- "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
- )
- score_writer = open(os.path.join(output_path, "score.txt"), "w")
- ref_embd_writer(key, ref_embedding[0].cpu().numpy())
- score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
- if output_path is not None:
- embd_writer.close()
- if ref_embd_writer is not None:
- ref_embd_writer.close()
- score_writer.close()
-
- return sv_result_list
-
- return _forward
-
-
-def inference(
- output_dir: Optional[str],
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
- key_file: Optional[str],
- sv_train_config: Optional[str],
- sv_model_file: Optional[str],
- model_tag: Optional[str],
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- output_dir=output_dir,
- batch_size=batch_size,
- dtype=dtype,
- ngpu=ngpu,
- seed=seed,
- num_workers=num_workers,
- log_level=log_level,
- key_file=key_file,
- sv_train_config=sv_train_config,
- sv_model_file=sv_model_file,
- model_tag=model_tag,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- embedding_node=embedding_node,
- sv_threshold=sv_threshold,
- **kwargs,
- )
-
- return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Speaker verification/x-vector extraction",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--sv_train_config",
- type=str,
- help="SV training configuration",
- )
- group.add_argument(
- "--sv_model_file",
- type=str,
- help="SV model parameter file",
- )
- group.add_argument(
- "--sv_threshold",
- type=float,
- default=0.9465,
- help="The threshold for verification"
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- parser.add_argument("--streaming", type=str2bool, default=False)
- parser.add_argument("--embedding_node", type=str, default="resnet1_dense")
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- logging.info("args: {}".format(kwargs))
- if args.output_dir is None:
- jobid, n_gpu = 1, 1
- gpuid = args.gpuid_list.split(",")[jobid-1]
- else:
- jobid = int(args.output_dir.split(".")[-1])
- n_gpu = len(args.gpuid_list.split(","))
- gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- results_list = inference(**kwargs)
- for results in results_list:
- print("{} {}".format(results["key"], results["value"]))
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 8806070..dbddd9f 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-
import argparse
import logging
@@ -14,7 +14,173 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import numpy as np
+import torch
+from kaldiio import WriteHelper
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.sv import SVTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.misc import statistic_model_parameters
+from funasr.bin.sv_infer import Speech2Xvector
+
+def inference_sv(
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pb",
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ sv_threshold: float = 0.9465,
+ param_dict: Optional[dict] = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("param_dict: {}".format(param_dict))
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2xvector
+ speech2xvector_kwargs = dict(
+ sv_train_config=sv_train_config,
+ sv_model_file=sv_model_file,
+ device=device,
+ dtype=dtype,
+ streaming=streaming,
+ embedding_node=embedding_node
+ )
+ logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
+ speech2xvector = Speech2Xvector.from_pretrained(
+ model_tag=model_tag,
+ **speech2xvector_kwargs,
+ )
+ speech2xvector.sv_model.eval()
+
+ def _forward(
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
+ ):
+ logging.info("param_dict: {}".format(param_dict))
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+
+ # 3. Build data-iterator
+ loader = SVTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=None,
+ collate_fn=None,
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+        # 4. Start the inference loop
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ embd_writer, ref_embd_writer, score_writer = None, None, None
+ if output_path is not None:
+ os.makedirs(output_path, exist_ok=True)
+ embd_writer = WriteHelper("ark,scp:{}/xvector.ark,{}/xvector.scp".format(output_path, output_path))
+ sv_result_list = []
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+ embedding, ref_embedding, score = speech2xvector(**batch)
+ # Only supporting batch_size==1
+ key = keys[0]
+ normalized_score = 0.0
+ if score is not None:
+ score = score.item()
+ normalized_score = max(score - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
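+                # Scores at or below sv_threshold map to 0; scores above it are
+                # rescaled linearly into (0, 100]. E.g. with the default threshold
+                # 0.9465, a raw cosine of 0.9679 gives (0.9679-0.9465)/0.0535*100 = 40.0.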
+ item = {"key": key, "value": normalized_score}
+ else:
+ item = {"key": key, "value": embedding.squeeze(0).cpu().numpy()}
+ sv_result_list.append(item)
+ if output_path is not None:
+ embd_writer(key, embedding[0].cpu().numpy())
+ if ref_embedding is not None:
+ if ref_embd_writer is None:
+ ref_embd_writer = WriteHelper(
+ "ark,scp:{}/ref_xvector.ark,{}/ref_xvector.scp".format(output_path, output_path)
+ )
+ score_writer = open(os.path.join(output_path, "score.txt"), "w")
+ ref_embd_writer(key, ref_embedding[0].cpu().numpy())
+ score_writer.write("{} {:.6f}\n".format(key, normalized_score))
+
+ if output_path is not None:
+ embd_writer.close()
+ if ref_embd_writer is not None:
+ ref_embd_writer.close()
+ score_writer.close()
+
+ return sv_result_list
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sv":
+ return inference_sv(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -131,15 +297,6 @@
return parser
-def inference_launch(mode, **kwargs):
- if mode == "sv":
- from funasr.bin.sv_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
@@ -167,7 +324,8 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch(**kwargs)
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
new file mode 100644
index 0000000..4ddcba4
--- /dev/null
+++ b/funasr/bin/tp_infer.py
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.tasks.asr import ASRTaskAligner as ASRTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
+
+
+class Speech2Timestamp:
+ def __init__(
+ self,
+ timestamp_infer_config: Union[Path, str] = None,
+ timestamp_model_file: Union[Path, str] = None,
+ timestamp_cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ dtype: str = "float32",
+ **kwargs,
+ ):
+ assert check_argument_types()
+        # 1. Build timestamp prediction model
+ tp_model, tp_train_args = ASRTask.build_model_from_file(
+ timestamp_infer_config, timestamp_model_file, device=device
+ )
+ if 'cuda' in device:
+ tp_model = tp_model.cuda() # force model to cuda
+
+ frontend = None
+ if tp_train_args.frontend is not None:
+ frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
+
+ logging.info("tp_model: {}".format(tp_model))
+ logging.info("tp_train_args: {}".format(tp_train_args))
+ tp_model.to(dtype=getattr(torch, dtype)).eval()
+
+ logging.info(f"Decoding device={device}, dtype={dtype}")
+
+ self.tp_model = tp_model
+ self.tp_train_args = tp_train_args
+
+ token_list = self.tp_model.token_list
+ self.converter = TokenIDConverter(token_list=token_list)
+
+ self.device = device
+ self.dtype = dtype
+ self.frontend = frontend
+ self.encoder_downsampling_factor = 1
+ if tp_train_args.encoder_conf["input_layer"] == "conv2d":
+ self.encoder_downsampling_factor = 4
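+            # A "conv2d" input layer subsamples the feature sequence by a factor
+            # of 4 (typically two stride-2 convolutions), so encoder-frame indices
+            # must be scaled back up when mapping to time.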
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+ text_lengths: Union[torch.Tensor, np.ndarray] = None
+ ):
+ assert check_argument_types()
+
+ # Input as audio signal
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.tp_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+
+ # lfr_factor = max(1, (feats.size()[-1]//80)-1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
+
+ # a. To device
+ batch = to_device(batch, device=self.device)
+
+ # b. Forward Encoder
+ enc, enc_len = self.tp_model.encode(**batch)
+ if isinstance(enc, tuple):
+ enc = enc[0]
+
+ # c. Forward Predictor
+ _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
+ return us_alphas, us_peaks
+
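+# Usage sketch (hedged): the file names below are hypothetical placeholders.
+# Speech2Timestamp consumes a waveform plus the length of its transcript and
+# returns the upsampled predictor activations used for timestamp decoding.
+#
+#   s2t = Speech2Timestamp(
+#       timestamp_infer_config="tp.yaml",
+#       timestamp_model_file="tp.pb",
+#       timestamp_cmvn_file="am.mvn",
+#   )
+#   us_alphas, us_peaks = s2t(speech, speech_lengths, text_lengths)
+#   # ts_prediction_lfr6_standard turns us_alphas/us_peaks into token timestamps.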
diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
deleted file mode 100644
index 6e513c5..0000000
--- a/funasr/bin/tp_inference.py
+++ /dev/null
@@ -1,399 +0,0 @@
-import argparse
-import logging
-from optparse import Option
-import sys
-import json
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner as ASRTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
-
-
-class SpeechText2Timestamp:
- def __init__(
- self,
- timestamp_infer_config: Union[Path, str] = None,
- timestamp_model_file: Union[Path, str] = None,
- timestamp_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- dtype: str = "float32",
- **kwargs,
- ):
- assert check_argument_types()
- # 1. Build ASR model
- tp_model, tp_train_args = ASRTask.build_model_from_file(
- timestamp_infer_config, timestamp_model_file, device=device
- )
- if 'cuda' in device:
- tp_model = tp_model.cuda() # force model to cuda
-
- frontend = None
- if tp_train_args.frontend is not None:
- frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
-
- logging.info("tp_model: {}".format(tp_model))
- logging.info("tp_train_args: {}".format(tp_train_args))
- tp_model.to(dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
-
- self.tp_model = tp_model
- self.tp_train_args = tp_train_args
-
- token_list = self.tp_model.token_list
- self.converter = TokenIDConverter(token_list=token_list)
-
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if tp_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self,
- speech: Union[torch.Tensor, np.ndarray],
- speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- text_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.tp_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- # lfr_factor = max(1, (feats.size()[-1]//80)-1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, enc_len = self.tp_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
-
- # c. Forward Predictor
- _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
- return us_alphas, us_peaks
-
-
-def inference(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- timestamp_infer_config: Optional[str],
- timestamp_model_file: Optional[str],
- timestamp_cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- split_with_space: bool = True,
- seg_dict_file: Optional[str] = None,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- batch_size=batch_size,
- ngpu=ngpu,
- log_level=log_level,
- timestamp_infer_config=timestamp_infer_config,
- timestamp_model_file=timestamp_model_file,
- timestamp_cmvn_file=timestamp_cmvn_file,
- key_file=key_file,
- allow_variable_data_keys=allow_variable_data_keys,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- num_workers=num_workers,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- timestamp_infer_config: Optional[str],
- timestamp_model_file: Optional[str],
- timestamp_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- split_with_space: bool = True,
- seg_dict_file: Optional[str] = None,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
- if ngpu > 1:
- raise NotImplementedError("only single GPU decoding is supported")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speechtext2timestamp_kwargs = dict(
- timestamp_infer_config=timestamp_infer_config,
- timestamp_model_file=timestamp_model_file,
- timestamp_cmvn_file=timestamp_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
- speechtext2timestamp = SpeechText2Timestamp(**speechtext2timestamp_kwargs)
-
- preprocessor = LMPreprocessor(
- train=False,
- token_type=speechtext2timestamp.tp_train_args.token_type,
- token_list=speechtext2timestamp.tp_train_args.token_list,
- bpemodel=None,
- text_cleaner=None,
- g2p_type=None,
- text_name="text",
- non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
- split_with_space=split_with_space,
- seg_dict_file=seg_dict_file,
- )
-
- if output_dir is not None:
- writer = DatadirWriter(output_dir)
- tp_writer = writer[f"timestamp_prediction"]
- # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
- else:
- tp_writer = None
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs
- ):
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- writer = None
- if output_path is not None:
- writer = DatadirWriter(output_path)
- tp_writer = writer[f"timestamp_prediction"]
- else:
- tp_writer = None
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=preprocessor,
- collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- tp_result_list = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- logging.info("timestamp predicting, utt_id: {}".format(keys))
- _batch = {'speech':batch['speech'],
- 'speech_lengths':batch['speech_lengths'],
- 'text_lengths':batch['text_lengths']}
- us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
-
- for batch_id in range(_bs):
- key = keys[batch_id]
- token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
- ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
- logging.warning(ts_str)
- item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
- if tp_writer is not None:
- tp_writer["tp_sync"][key+'#'] = ts_str
- tp_writer["tp_time"][key+'#'] = str(ts_list)
- tp_result_list.append(item)
- return tp_result_list
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="Timestamp Prediction Inference",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=0,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--timestamp_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--timestamp_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--timestamp_cmvn_file",
- type=str,
- help="Global cmvn file",
- )
-
- group = parser.add_argument_group("infer related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument(
- "--seg_dict_file",
- type=str,
- default=None,
- help="The batch size for inference",
- )
- group.add_argument(
- "--split_with_space",
- type=bool,
- default=False,
- help="The batch size for inference",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py
index 6cdff05..a8d67ef 100644
--- a/funasr/bin/tp_inference_launch.py
+++ b/funasr/bin/tp_inference_launch.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import argparse
@@ -13,6 +16,180 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+import logging
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.tasks.asr import ASRTaskAligner as ASRTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.bin.tp_infer import Speech2Timestamp
+
+def inference_tp(
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ timestamp_infer_config: Optional[str],
+ timestamp_model_file: Optional[str],
+ timestamp_cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ split_with_space: bool = True,
+ seg_dict_file: Optional[str] = None,
+ **kwargs,
+):
+ assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+    # 2. Build speechtext2timestamp
+ speechtext2timestamp_kwargs = dict(
+ timestamp_infer_config=timestamp_infer_config,
+ timestamp_model_file=timestamp_model_file,
+ timestamp_cmvn_file=timestamp_cmvn_file,
+ device=device,
+ dtype=dtype,
+ )
+ logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
+ speechtext2timestamp = Speech2Timestamp(**speechtext2timestamp_kwargs)
+
+ preprocessor = LMPreprocessor(
+ train=False,
+ token_type=speechtext2timestamp.tp_train_args.token_type,
+ token_list=speechtext2timestamp.tp_train_args.token_list,
+ bpemodel=None,
+ text_cleaner=None,
+ g2p_type=None,
+ text_name="text",
+ non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ )
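+    # The preprocessor tokenizes the paired transcript so that the data loader
+    # can supply 'text' token ids and 'text_lengths' alongside the waveform.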
+
+ if output_dir is not None:
+ writer = DatadirWriter(output_dir)
+ tp_writer = writer[f"timestamp_prediction"]
+ # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
+ else:
+ tp_writer = None
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs
+ ):
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ writer = None
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ tp_writer = writer[f"timestamp_prediction"]
+ else:
+ tp_writer = None
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+
+ loader = ASRTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=preprocessor,
+ collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ tp_result_list = []
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ logging.info("timestamp predicting, utt_id: {}".format(keys))
+ _batch = {'speech': batch['speech'],
+ 'speech_lengths': batch['speech_lengths'],
+ 'text_lengths': batch['text_lengths']}
+ us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
+
+ for batch_id in range(_bs):
+ key = keys[batch_id]
+ token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
+ ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token,
+ force_time_shift=-3.0)
+ logging.warning(ts_str)
+ item = {'key': key, 'value': ts_str, 'timestamp': ts_list}
+ if tp_writer is not None:
+ tp_writer["tp_sync"][key + '#'] = ts_str
+ tp_writer["tp_time"][key + '#'] = str(ts_list)
+ tp_result_list.append(item)
+ return tp_result_list
+
+ return _forward
+
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "tp_norm":
+ return inference_tp(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -100,14 +277,6 @@
return parser
-def inference_launch(mode, **kwargs):
- if mode == "tp_norm":
- from funasr.bin.tp_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
@@ -135,7 +304,9 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch(**kwargs)
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
+
if __name__ == "__main__":
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
new file mode 100755
index 0000000..b0d46e7
--- /dev/null
+++ b/funasr/bin/train.py
@@ -0,0 +1,572 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+import os
+import sys
+from io import BytesIO
+
+import torch
+
+from funasr.build_utils.build_args import build_args
+from funasr.build_utils.build_dataloader import build_dataloader
+from funasr.build_utils.build_distributed import build_distributed
+from funasr.build_utils.build_model import build_model
+from funasr.build_utils.build_optimizer import build_optimizer
+from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.build_utils.build_trainer import build_trainer
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+from funasr.torch_utils.model_summary import model_summary
+from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.prepare_data import prepare_data
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="FunASR Common Training Parser",
+ )
+
+ # common configuration
+ parser.add_argument("--output_dir", help="model save path")
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
+
+ # ddp related
+ parser.add_argument(
+ "--dist_backend",
+ default="nccl",
+ type=str,
+ help="distributed backend",
+ )
+ parser.add_argument(
+ "--dist_init_method",
+ type=str,
+ default="env://",
+ help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
+ '"WORLD_SIZE", and "RANK" are referred.',
+ )
+ parser.add_argument(
+ "--dist_world_size",
+ type=int,
+ default=1,
+ help="number of nodes for distributed training",
+ )
+ parser.add_argument(
+ "--dist_rank",
+ type=int,
+ default=None,
+ help="node rank for distributed training",
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=None,
+ help="local rank for distributed training",
+ )
+ parser.add_argument(
+ "--dist_master_addr",
+ default=None,
+ type=str_or_none,
+ help="The master address for distributed training. "
+ "This value is used when dist_init_method == 'env://'",
+ )
+ parser.add_argument(
+ "--dist_master_port",
+ default=None,
+ type=int_or_none,
+ help="The master port for distributed training"
+ "This value is used when dist_init_method == 'env://'",
+ )
+ parser.add_argument(
+ "--dist_launcher",
+ default=None,
+ type=str_or_none,
+ choices=["slurm", "mpi", None],
+ help="The launcher type for distributed training",
+ )
+ parser.add_argument(
+ "--multiprocessing_distributed",
+ default=True,
+ type=str2bool,
+ help="Use multi-processing distributed training to launch "
+ "N processes per node, which has N GPUs. This is the "
+ "fastest way to use PyTorch for either single node or "
+ "multi node data parallel training",
+ )
+ parser.add_argument(
+ "--unused_parameters",
+ type=str2bool,
+ default=False,
+ help="Whether to use the find_unused_parameters in "
+ "torch.nn.parallel.DistributedDataParallel ",
+ )
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+
+ # cudnn related
+ parser.add_argument(
+ "--cudnn_enabled",
+ type=str2bool,
+ default=torch.backends.cudnn.enabled,
+ help="Enable CUDNN",
+ )
+ parser.add_argument(
+ "--cudnn_benchmark",
+ type=str2bool,
+ default=torch.backends.cudnn.benchmark,
+ help="Enable cudnn-benchmark mode",
+ )
+ parser.add_argument(
+ "--cudnn_deterministic",
+ type=str2bool,
+ default=True,
+ help="Enable cudnn-deterministic mode",
+ )
+
+ # trainer related
+ parser.add_argument(
+ "--max_epoch",
+ type=int,
+ default=40,
+ help="The maximum number epoch to train",
+ )
+ parser.add_argument(
+ "--max_update",
+ type=int,
+ default=sys.maxsize,
+ help="The maximum number update step to train",
+ )
+ parser.add_argument(
+ "--batch_interval",
+ type=int,
+ default=10000,
+ help="The batch interval for saving model.",
+ )
+ parser.add_argument(
+ "--patience",
+ type=int_or_none,
+ default=None,
+ help="Number of epochs to wait without improvement "
+ "before stopping the training",
+ )
+ parser.add_argument(
+ "--val_scheduler_criterion",
+ type=str,
+ nargs=2,
+ default=("valid", "loss"),
+ help="The criterion used for the value given to the lr scheduler. "
+        'Give a pair referring to the phase, "train" or "valid", '
+        'and the criterion name. The mode specifying "min" or "max" can '
+ "be changed by --scheduler_conf",
+ )
+ parser.add_argument(
+ "--early_stopping_criterion",
+ type=str,
+ nargs=3,
+ default=("valid", "loss", "min"),
+ help="The criterion used for judging of early stopping. "
+ 'Give a pair referring the phase, "train" or "valid",'
+ 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
+ )
+ parser.add_argument(
+ "--best_model_criterion",
+ nargs="+",
+ default=[
+ ("train", "loss", "min"),
+ ("valid", "loss", "min"),
+ ("train", "acc", "max"),
+ ("valid", "acc", "max"),
+ ],
+ help="The criterion used for judging of the best model. "
+ 'Give a pair referring the phase, "train" or "valid",'
+ 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
+ )
+ parser.add_argument(
+ "--keep_nbest_models",
+ type=int,
+ nargs="+",
+ default=[10],
+ help="Remove previous snapshots excluding the n-best scored epochs",
+ )
+ parser.add_argument(
+ "--nbest_averaging_interval",
+ type=int,
+ default=0,
+ help="The epoch interval to apply model averaging and save nbest models",
+ )
+ parser.add_argument(
+ "--grad_clip",
+ type=float,
+ default=5.0,
+ help="Gradient norm threshold to clip",
+ )
+ parser.add_argument(
+ "--grad_clip_type",
+ type=float,
+ default=2.0,
+ help="The type of the used p-norm for gradient clip. Can be inf",
+ )
+ parser.add_argument(
+ "--grad_noise",
+ type=str2bool,
+ default=False,
+ help="The flag to switch to use noise injection to "
+ "gradients during training",
+ )
+ parser.add_argument(
+ "--accum_grad",
+ type=int,
+ default=1,
+ help="The number of gradient accumulation",
+ )
+ parser.add_argument(
+ "--resume",
+ type=str2bool,
+ default=False,
+ help="Enable resuming if checkpoint is existing",
+ )
+ parser.add_argument(
+ "--train_dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type for training.",
+ )
+ parser.add_argument(
+ "--use_amp",
+ type=str2bool,
+ default=False,
+ help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
+ )
+ parser.add_argument(
+ "--log_interval",
+ default=None,
+ help="Show the logs every the number iterations in each epochs at the "
+ "training phase. If None is given, it is decided according the number "
+ "of training samples automatically .",
+ )
+ parser.add_argument(
+ "--use_tensorboard",
+ type=str2bool,
+ default=True,
+ help="Enable tensorboard logging",
+ )
+
+ # pretrained model related
+ parser.add_argument(
+ "--init_param",
+ type=str,
+ default=[],
+ nargs="*",
+ help="Specify the file path used for initialization of parameters. "
+ "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
+ "where file_path is the model file path, "
+ "src_key specifies the key of model states to be used in the model file, "
+ "dst_key specifies the attribute of the model to be initialized, "
+ "and exclude_keys excludes keys of model states for the initialization."
+ "e.g.\n"
+ " # Load all parameters"
+ " --init_param some/where/model.pb\n"
+ " # Load only decoder parameters"
+ " --init_param some/where/model.pb:decoder:decoder\n"
+ " # Load only decoder parameters excluding decoder.embed"
+ " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+ " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
+ )
+ parser.add_argument(
+ "--ignore_init_mismatch",
+ type=str2bool,
+ default=False,
+ help="Ignore size mismatch when loading pre-trained model",
+ )
+ parser.add_argument(
+ "--freeze_param",
+ type=str,
+ default=[],
+ nargs="*",
+ help="Freeze parameters",
+ )
+
+ # dataset related
+ parser.add_argument(
+ "--dataset_type",
+ type=str,
+ default="small",
+ help="whether to use dataloader for large dataset",
+ )
+ parser.add_argument(
+ "--dataset_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help=f"The keyword arguments for dataset",
+ )
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default=None,
+ help="root path of data",
+ )
+ parser.add_argument(
+ "--train_set",
+ type=str,
+ default="train",
+ help="train dataset",
+ )
+ parser.add_argument(
+ "--valid_set",
+ type=str,
+ default="validation",
+ help="dev dataset",
+ )
+ parser.add_argument(
+ "--data_file_names",
+ type=str,
+ default="wav.scp,text",
+ help="input data files",
+ )
+ parser.add_argument(
+ "--speed_perturb",
+ type=float,
+ nargs="+",
+ default=None,
+ help="speed perturb",
+ )
+ parser.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+
+ # optimization related
+ parser.add_argument(
+ "--optim",
+ type=lambda x: x.lower(),
+ default="adam",
+ help="The optimizer type",
+ )
+ parser.add_argument(
+ "--optim_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help="The keyword arguments for optimizer",
+ )
+ parser.add_argument(
+ "--scheduler",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The lr scheduler type",
+ )
+ parser.add_argument(
+ "--scheduler_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help="The keyword arguments for lr scheduler",
+ )
+
+ # most task related
+ parser.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+ parser.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ parser.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word"],
+ help="",
+ )
+ parser.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model file fo sentencepiece",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+
+ # pai related
+ parser.add_argument(
+ "--use_pai",
+ type=str2bool,
+ default=False,
+ help="flag to indicate whether training on PAI",
+ )
+ parser.add_argument(
+ "--simple_ddp",
+ type=str2bool,
+ default=False,
+ )
+ parser.add_argument(
+ "--num_worker_count",
+ type=int,
+ default=1,
+ help="The number of machines on PAI.",
+ )
+    parser.add_argument(
+        "--access_key_id",
+        type=str,
+        default=None,
+        help="The access key ID for OSS.",
+    )
+    parser.add_argument(
+        "--access_key_secret",
+        type=str,
+        default=None,
+        help="The access key secret for OSS.",
+    )
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default=None,
+        help="The endpoint for OSS.",
+    )
+    parser.add_argument(
+        "--bucket_name",
+        type=str,
+        default=None,
+        help="The bucket name for OSS.",
+    )
+    parser.add_argument(
+        "--oss_bucket",
+        default=None,
+        help="The OSS bucket object.",
+    )
+
+ return parser
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ args, extra_task_params = parser.parse_known_args()
+ if extra_task_params:
+ args = build_args(args, parser, extra_task_params)
+
+ # set random seed
+ set_all_random_seed(args.seed)
+ torch.backends.cudnn.enabled = args.cudnn_enabled
+ torch.backends.cudnn.benchmark = args.cudnn_benchmark
+ torch.backends.cudnn.deterministic = args.cudnn_deterministic
+
+ # ddp init
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+ args.distributed = args.ngpu > 1 or args.dist_world_size > 1
+ distributed_option = build_distributed(args)
+
+ # for logging
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ logging.basicConfig(
+ level="INFO",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ else:
+ logging.basicConfig(
+ level="ERROR",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ # prepare files for dataloader
+ prepare_data(args, distributed_option)
+
+ model = build_model(args)
+ model = model.to(
+ dtype=getattr(torch, args.train_dtype),
+ device="cuda" if args.ngpu > 0 else "cpu",
+ )
+ optimizers = build_optimizer(args, model=model)
+ schedulers = build_scheduler(args, optimizers)
+
+ logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+ distributed_option.dist_rank,
+ distributed_option.local_rank))
+ logging.info(pytorch_cudnn_version())
+ logging.info("Args: {}".format(args))
+ logging.info(model_summary(model))
+ logging.info("Optimizer: {}".format(optimizers))
+ logging.info("Scheduler: {}".format(schedulers))
+
+ # dump args to config.yaml
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        os.makedirs(args.output_dir, exist_ok=True)
+        logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
+        if args.use_pai:
+            # on PAI the config is serialized to OSS instead of the local disk
+            buffer = BytesIO()
+            torch.save({"config": vars(args)}, buffer)
+            args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
+        else:
+            # only open the local file when it will actually be written
+            with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
+                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
+
+ for p in args.init_param:
+ logging.info(f"Loading pretrained params from {p}")
+ load_pretrained_model(
+ model=model,
+ init_param=p,
+ ignore_init_mismatch=args.ignore_init_mismatch,
+ map_location=f"cuda:{torch.cuda.current_device()}"
+ if args.ngpu > 0
+ else "cpu",
+ oss_bucket=args.oss_bucket,
+ )
+
+ # dataloader for training/validation
+ train_dataloader, valid_dataloader = build_dataloader(args)
+
+ # Trainer, including model, optimizers, etc.
+ trainer = build_trainer(
+ args=args,
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ distributed_option=distributed_option
+ )
+
+ trainer.run()
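
For orientation, the entry point above can be driven end-to-end from a small
launcher. The sketch below is illustrative only: the module path
funasr.bin.train and all data paths are assumptions, not taken from this patch.

# Hypothetical launcher for the common training entry point above.
# The module path and data paths are placeholders.
import subprocess
import sys

cmd = [
    sys.executable, "-m", "funasr.bin.train",  # assumed module path
    "--task_name", "asr",
    "--output_dir", "exp/asr_train",           # illustrative output dir
    "--data_dir", "data",                      # illustrative data root
    "--train_set", "train",
    "--valid_set", "dev",
    "--ngpu", "1",
    "--max_epoch", "40",
]
subprocess.run(cmd, check=True)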
diff --git a/funasr/bin/vad_infer.py b/funasr/bin/vad_infer.py
new file mode 100644
index 0000000..245757c
--- /dev/null
+++ b/funasr/bin/vad_infer.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import math
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.vad import VADTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+
+
+
+class Speech2VadSegment:
+ """Speech2VadSegment class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2segment(audio)
+ [[10, 230], [245, 450], ...]
+
+ """
+
+ def __init__(
+ self,
+ vad_infer_config: Union[Path, str] = None,
+ vad_model_file: Union[Path, str] = None,
+ vad_cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ batch_size: int = 1,
+ dtype: str = "float32",
+ **kwargs,
+ ):
+ assert check_argument_types()
+
+ # 1. Build vad model
+ vad_model, vad_infer_args = VADTask.build_model_from_file(
+ vad_infer_config, vad_model_file, device
+ )
+ frontend = None
+ if vad_infer_args.frontend is not None:
+ frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
+
+ logging.info("vad_model: {}".format(vad_model))
+ logging.info("vad_infer_args: {}".format(vad_infer_args))
+ vad_model.to(dtype=getattr(torch, dtype)).eval()
+
+ self.vad_model = vad_model
+ self.vad_infer_args = vad_infer_args
+ self.device = device
+ self.dtype = dtype
+ self.frontend = frontend
+ self.batch_size = batch_size
+
+ @torch.no_grad()
+ def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+            in_cache: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[torch.Tensor, List[List[int]]]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+            fbanks, segments
+
+ """
+ assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if in_cache is None:
+            in_cache = {}
+
+ if self.frontend is not None:
+ self.frontend.filter_length_max = math.inf
+ fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
+ feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
+ fbanks = to_device(fbanks, device=self.device)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ else:
+ raise Exception("Need to extract feats first, please configure frontend configuration")
+
+ # b. Forward Encoder streaming
+ t_offset = 0
+ step = min(feats_len.max(), 6000)
+        segments = [[] for _ in range(self.batch_size)]  # independent list per batch item
+ for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+ if t_offset + step >= feats_len - 1:
+ step = feats_len - t_offset
+ is_final = True
+ else:
+ is_final = False
+ batch = {
+ "feats": feats[:, t_offset:t_offset + step, :],
+ "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
+ "is_final": is_final,
+ "in_cache": in_cache
+ }
+ # a. To device
+ #batch = to_device(batch, device=self.device)
+ segments_part, in_cache = self.vad_model(**batch)
+ if segments_part:
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
+ return fbanks, segments
+
+class Speech2VadSegmentOnline(Speech2VadSegment):
+ """Speech2VadSegmentOnline class
+
+ Examples:
+ >>> import soundfile
+ >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
+ >>> audio, rate = soundfile.read("speech.wav")
+ >>> speech2segment(audio)
+ [[10, 230], [245, 450], ...]
+
+ """
+ def __init__(self, **kwargs):
+ super(Speech2VadSegmentOnline, self).__init__(**kwargs)
+ vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
+ self.frontend = None
+ if self.vad_infer_args.frontend is not None:
+ self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
+
+
+ @torch.no_grad()
+ def __call__(
+ self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+            in_cache: Optional[Dict[str, torch.Tensor]] = None, is_final: bool = False, max_end_sil: int = 800
+    ) -> Tuple[torch.Tensor, List[List[int]], Dict[str, torch.Tensor]]:
+ """Inference
+
+ Args:
+ speech: Input speech data
+ Returns:
+            fbanks, segments, in_cache
+
+ """
+ assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if in_cache is None:
+            in_cache = {}
+        batch_size = speech.shape[0]
+        segments = [[] for _ in range(batch_size)]  # independent list per batch item
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
+ fbanks, _ = self.frontend.get_fbank()
+ else:
+ raise Exception("Need to extract feats first, please configure frontend configuration")
+ if feats.shape[0]:
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ waveforms = self.frontend.get_waveforms()
+
+ batch = {
+ "feats": feats,
+ "waveform": waveforms,
+ "in_cache": in_cache,
+ "is_final": is_final,
+ "max_end_sil": max_end_sil
+ }
+ # a. To device
+ batch = to_device(batch, device=self.device)
+ segments, in_cache = self.vad_model.forward_online(**batch)
+ # in_cache.update(batch['in_cache'])
+ # in_cache = {key: value for key, value in batch['in_cache'].items()}
+ return fbanks, segments, in_cache
+
+
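
The offline class above is self-contained, so a minimal usage sketch looks as
follows. The config, checkpoint, and cmvn paths are placeholders, and the
(1, T) batching mirrors what the streaming data iterator produces.

# Minimal offline VAD sketch using Speech2VadSegment; paths are placeholders.
import soundfile
import torch

from funasr.bin.vad_infer import Speech2VadSegment

vad = Speech2VadSegment(
    vad_infer_config="vad_config.yml",  # placeholder config path
    vad_model_file="vad.pt",            # placeholder checkpoint path
    vad_cmvn_file="vad.cmvn",           # placeholder cmvn path
    device="cpu",
)
audio, rate = soundfile.read("speech.wav")
speech = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)  # (1, T)
lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
fbanks, segments = vad(speech, lengths)
print(segments)  # one [start, end] list per utterance, as in the docstring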
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
deleted file mode 100644
index 5fbd844..0000000
--- a/funasr/bin/vad_inference.py
+++ /dev/null
@@ -1,570 +0,0 @@
-import argparse
-import logging
-import os
-import sys
-import json
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import math
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-global_asr_language: str = 'zh-cn'
-global_sample_rate: Union[int, Dict[Any, int]] = {
- 'audio_fs': 16000,
- 'model_fs': 16000
-}
-
-
-class Speech2VadSegment:
- """Speech2VadSegment class
-
- Examples:
- >>> import soundfile
- >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2segment(audio)
- [[10, 230], [245, 450], ...]
-
- """
-
- def __init__(
- self,
- vad_infer_config: Union[Path, str] = None,
- vad_model_file: Union[Path, str] = None,
- vad_cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- batch_size: int = 1,
- dtype: str = "float32",
- **kwargs,
- ):
- assert check_argument_types()
-
- # 1. Build vad model
- vad_model, vad_infer_args = VADTask.build_model_from_file(
- vad_infer_config, vad_model_file, device
- )
- frontend = None
- if vad_infer_args.frontend is not None:
- frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-
- logging.info("vad_model: {}".format(vad_model))
- logging.info("vad_infer_args: {}".format(vad_infer_args))
- vad_model.to(dtype=getattr(torch, dtype)).eval()
-
- self.vad_model = vad_model
- self.vad_infer_args = vad_infer_args
- self.device = device
- self.dtype = dtype
- self.frontend = frontend
- self.batch_size = batch_size
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict()
- ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- self.frontend.filter_length_max = math.inf
- fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
- feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
- fbanks = to_device(fbanks, device=self.device)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- else:
- raise Exception("Need to extract feats first, please configure frontend configuration")
-
- # b. Forward Encoder streaming
- t_offset = 0
- step = min(feats_len.max(), 6000)
- segments = [[]] * self.batch_size
- for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
- if t_offset + step >= feats_len - 1:
- step = feats_len - t_offset
- is_final = True
- else:
- is_final = False
- batch = {
- "feats": feats[:, t_offset:t_offset + step, :],
- "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
- "is_final": is_final,
- "in_cache": in_cache
- }
- # a. To device
- #batch = to_device(batch, device=self.device)
- segments_part, in_cache = self.vad_model(**batch)
- if segments_part:
- for batch_num in range(0, self.batch_size):
- segments[batch_num] += segments_part[batch_num]
- return fbanks, segments
-
-class Speech2VadSegmentOnline(Speech2VadSegment):
- """Speech2VadSegmentOnline class
-
- Examples:
- >>> import soundfile
- >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2segment(audio)
- [[10, 230], [245, 450], ...]
-
- """
- def __init__(self, **kwargs):
- super(Speech2VadSegmentOnline, self).__init__(**kwargs)
- vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
- self.frontend = None
- if self.vad_infer_args.frontend is not None:
- self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- batch_size = speech.shape[0]
- segments = [[]] * batch_size
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
- fbanks, _ = self.frontend.get_fbank()
- else:
- raise Exception("Need to extract feats first, please configure frontend configuration")
- if feats.shape[0]:
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- waveforms = self.frontend.get_waveforms()
-
- batch = {
- "feats": feats,
- "waveform": waveforms,
- "in_cache": in_cache,
- "is_final": is_final,
- "max_end_sil": max_end_sil
- }
- # a. To device
- batch = to_device(batch, device=self.device)
- segments, in_cache = self.vad_model.forward_online(**batch)
- # in_cache.update(batch['in_cache'])
- # in_cache = {key: value for key, value in batch['in_cache'].items()}
- return fbanks, segments, in_cache
-
-
-def inference(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- online: bool = False,
- **kwargs,
-):
- if not online:
- inference_pipeline = inference_modelscope(
- batch_size=batch_size,
- ngpu=ngpu,
- log_level=log_level,
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- key_file=key_file,
- allow_variable_data_keys=allow_variable_data_keys,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- num_workers=num_workers,
- **kwargs,
- )
- else:
- inference_pipeline = inference_modelscope_online(
- batch_size=batch_size,
- ngpu=ngpu,
- log_level=log_level,
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- key_file=key_file,
- allow_variable_data_keys=allow_variable_data_keys,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- num_workers=num_workers,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-def inference_modelscope(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
- assert check_argument_types()
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
- else:
- writer = None
- ibest_writer = None
-
- vad_results = []
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-
- # do vad segment
- _, results = speech2vadsegment(**batch)
- for i, _ in enumerate(keys):
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
- item = {'key': keys[i], 'value': results[i]}
- vad_results.append(item)
- if writer is not None:
- ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
- return vad_results
-
- return _forward
-
-def inference_modelscope_online(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
- assert check_argument_types()
-
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
-
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
- else:
- writer = None
- ibest_writer = None
-
- vad_results = []
- batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
- is_final = param_dict.get('is_final', False) if param_dict is not None else False
- max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch['in_cache'] = batch_in_cache
- batch['is_final'] = is_final
- batch['max_end_sil'] = max_end_sil
-
- # do vad segment
- _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
- # param_dict['in_cache'] = batch['in_cache']
- if results:
- for i, _ in enumerate(keys):
- if results[i]:
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
- item = {'key': keys[i], 'value': results[i]}
- vad_results.append(item)
- if writer is not None:
- ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
- return vad_results
-
- return _forward
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="VAD Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--vad_cmvn_file",
- type=str,
- help="Global cmvn file",
- )
- group.add_argument(
- "--online",
- type=str,
- help="decoding mode",
- )
-
- group = parser.add_argument_group("infer related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
-
diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index de58925..b17d058 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
torch.set_num_threads(1)
@@ -17,6 +19,270 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+import argparse
+import logging
+import os
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import math
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.vad import VADTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
+
+def inference_vad(
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ vad_infer_config: Optional[str],
+ vad_model_file: Optional[str],
+ vad_cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ **kwargs,
+):
+ assert check_argument_types()
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2vadsegment
+ speech2vadsegment_kwargs = dict(
+ vad_infer_config=vad_infer_config,
+ vad_model_file=vad_model_file,
+ vad_cmvn_file=vad_cmvn_file,
+ device=device,
+ dtype=dtype,
+ )
+ logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+ speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = VADTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+ collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+        # 7. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ ibest_writer = writer[f"1best_recog"]
+ else:
+ writer = None
+ ibest_writer = None
+
+ vad_results = []
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+ # do vad segment
+ _, results = speech2vadsegment(**batch)
+ for i, _ in enumerate(keys):
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ results[i] = json.dumps(results[i])
+ item = {'key': keys[i], 'value': results[i]}
+ vad_results.append(item)
+ if writer is not None:
+ ibest_writer["text"][keys[i]] = "{}".format(results[i])
+
+ return vad_results
+
+ return _forward
+
+def inference_vad_online(
+ batch_size: int,
+ ngpu: int,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ vad_infer_config: Optional[str],
+ vad_model_file: Optional[str],
+ vad_cmvn_file: Optional[str] = None,
+ # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ num_workers: int = 1,
+ **kwargs,
+):
+ assert check_argument_types()
+
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+ batch_size = 1
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2vadsegment
+ speech2vadsegment_kwargs = dict(
+ vad_infer_config=vad_infer_config,
+ vad_model_file=vad_model_file,
+ vad_cmvn_file=vad_cmvn_file,
+ device=device,
+ dtype=dtype,
+ )
+ logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+ speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
+
+ def _forward(
+ data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ ):
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = VADTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+ collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
+
+ finish_count = 0
+ file_count = 1
+        # 7. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ ibest_writer = writer[f"1best_recog"]
+ else:
+ writer = None
+ ibest_writer = None
+
+ vad_results = []
+ if param_dict is None:
+ param_dict = dict()
+ param_dict['in_cache'] = dict()
+ param_dict['is_final'] = True
+ batch_in_cache = param_dict.get('in_cache', dict())
+ is_final = param_dict.get('is_final', False)
+ max_end_sil = param_dict.get('max_end_sil', 800)
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch['in_cache'] = batch_in_cache
+ batch['is_final'] = is_final
+ batch['max_end_sil'] = max_end_sil
+
+ # do vad segment
+ _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
+ # param_dict['in_cache'] = batch['in_cache']
+ if results:
+ for i, _ in enumerate(keys):
+ if results[i]:
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ results[i] = json.dumps(results[i])
+ item = {'key': keys[i], 'value': results[i]}
+ vad_results.append(item)
+ if writer is not None:
+ ibest_writer["text"][keys[i]] = "{}".format(results[i])
+
+ return vad_results
+
+ return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "offline":
+ return inference_vad(**kwargs)
+ elif mode == "online":
+ return inference_vad_online(**kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
def get_parser():
parser = config_argparse.ArgumentParser(
@@ -109,17 +375,6 @@
return parser
-def inference_launch(mode, **kwargs):
- if mode == "offline":
- from funasr.bin.vad_inference import inference_modelscope
- return inference_modelscope(**kwargs)
- elif mode == "online":
- from funasr.bin.vad_inference import inference_modelscope_online
- return inference_modelscope_online(**kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
@@ -147,8 +402,8 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
- inference_launch(**kwargs)
-
+ inference_pipeline = inference_launch(**kwargs)
+ return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":
main()
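
With the dispatcher in place, the launch script can also be used
programmatically. A sketch under assumed inputs: the model files are
placeholders, and the ("wav.scp", "speech", "sound") triple follows the
--data_path_and_name_and_type convention used by this parser.

# Offline VAD through inference_launch; paths are placeholders.
from funasr.bin.vad_inference_launch import inference_launch

pipeline = inference_launch(
    mode="offline",
    batch_size=1,
    ngpu=0,
    log_level="INFO",
    vad_infer_config="vad_config.yml",  # placeholder
    vad_model_file="vad.pt",            # placeholder
    vad_cmvn_file="vad.cmvn",           # placeholder
    output_dir=None,
)
results = pipeline([("wav.scp", "speech", "sound")])
for item in results:
    print(item["key"], item["value"])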
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
deleted file mode 100644
index a363309..0000000
--- a/funasr/bin/vad_inference_online.py
+++ /dev/null
@@ -1,344 +0,0 @@
-import argparse
-import logging
-import os
-import sys
-import json
-from pathlib import Path
-from typing import Any
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import numpy as np
-import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
-from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.models.frontend.wav_frontend import WavFrontendOnline
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.bin.vad_inference import Speech2VadSegment
-
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
-
-class Speech2VadSegmentOnline(Speech2VadSegment):
- """Speech2VadSegmentOnline class
-
- Examples:
- >>> import soundfile
- >>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2segment(audio)
- [[10, 230], [245, 450], ...]
-
- """
- def __init__(self, **kwargs):
- super(Speech2VadSegmentOnline, self).__init__(**kwargs)
- vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
- self.frontend = None
- if self.vad_infer_args.frontend is not None:
- self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
-
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
- batch_size = speech.shape[0]
- segments = [[]] * batch_size
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
- fbanks, _ = self.frontend.get_fbank()
- else:
- raise Exception("Need to extract feats first, please configure frontend configuration")
- if feats.shape[0]:
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- waveforms = self.frontend.get_waveforms()
-
- batch = {
- "feats": feats,
- "waveform": waveforms,
- "in_cache": in_cache,
- "is_final": is_final,
- "max_end_sil": max_end_sil
- }
- # a. To device
- batch = to_device(batch, device=self.device)
- segments, in_cache = self.vad_model.forward_online(**batch)
- # in_cache.update(batch['in_cache'])
- # in_cache = {key: value for key, value in batch['in_cache'].items()}
- return fbanks, segments, in_cache
-
-
-def inference(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- batch_size=batch_size,
- ngpu=ngpu,
- log_level=log_level,
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- key_file=key_file,
- allow_variable_data_keys=allow_variable_data_keys,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- num_workers=num_workers,
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- batch_size: int,
- ngpu: int,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- vad_infer_config: Optional[str],
- vad_model_file: Optional[str],
- vad_cmvn_file: Optional[str] = None,
- # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- key_file: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- num_workers: int = 1,
- **kwargs,
-):
- assert check_argument_types()
- ncpu = kwargs.get("ncpu", 1)
- torch.set_num_threads(ncpu)
-
- if batch_size > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
-
- if ngpu >= 1 and torch.cuda.is_available():
- device = "cuda"
- else:
- device = "cpu"
- batch_size = 1
- # 1. Set random-seed
- set_all_random_seed(seed)
-
- # 2. Build speech2vadsegment
- speech2vadsegment_kwargs = dict(
- vad_infer_config=vad_infer_config,
- vad_model_file=vad_model_file,
- vad_cmvn_file=vad_cmvn_file,
- device=device,
- dtype=dtype,
- )
- logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
- speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
-
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- ):
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = VADTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
- collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- ibest_writer = writer[f"1best_recog"]
- else:
- writer = None
- ibest_writer = None
-
- vad_results = []
- batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
- is_final = param_dict.get('is_final', False) if param_dict is not None else False
- max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch['in_cache'] = batch_in_cache
- batch['is_final'] = is_final
- batch['max_end_sil'] = max_end_sil
-
- # do vad segment
- _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
- # param_dict['in_cache'] = batch['in_cache']
- if results:
- for i, _ in enumerate(keys):
- if results[i]:
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
- item = {'key': keys[i], 'value': results[i]}
- vad_results.append(item)
- if writer is not None:
- ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
- return vad_results
-
- return _forward
-
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="VAD Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=False)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=False,
- action="append",
- )
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
- group.add_argument("--key_file", type=str_or_none)
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--vad_cmvn_file",
- type=str,
- help="Global cmvn file",
- )
-
- group = parser.add_argument_group("infer related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
-
- return parser
-
-
-def main(cmd=None):
- print(get_commandline_args(), file=sys.stderr)
- parser = get_parser()
- args = parser.parse_args(cmd)
- kwargs = vars(args)
- kwargs.pop("config", None)
- inference(**kwargs)
-
-
-if __name__ == "__main__":
- main()
diff --git a/funasr/lm/__init__.py b/funasr/build_utils/__init__.py
similarity index 100%
rename from funasr/lm/__init__.py
rename to funasr/build_utils/__init__.py
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
new file mode 100644
index 0000000..517c85b
--- /dev/null
+++ b/funasr/build_utils/build_args.py
@@ -0,0 +1,93 @@
+from funasr.models.ctc import CTC
+from funasr.utils import config_argparse
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def build_args(args, parser, extra_task_params):
+ task_parser = config_argparse.ArgumentParser("Task related config")
+ if args.task_name == "asr":
+ from funasr.build_utils.build_asr_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ task_parser.add_argument(
+ "--seg_dict_file",
+ type=str,
+ default=None,
+ help="seg_dict_file for text processing",
+ )
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+ task_parser.add_argument(
+ "--ctc_conf",
+ action=NestedDictAction,
+ default=get_default_kwargs(CTC),
+ help="The keyword arguments for CTC class.",
+ )
+ task_parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+
+ elif args.task_name == "pretrain":
+ from funasr.build_utils.build_pretrain_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ elif args.task_name == "lm":
+ from funasr.build_utils.build_lm_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ elif args.task_name == "punc":
+ from funasr.build_utils.build_punc_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ elif args.task_name == "vad":
+ from funasr.build_utils.build_vad_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ elif args.task_name == "diar":
+ from funasr.build_utils.build_diar_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ else:
+ raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+    # Merge the common arguments into the task parser so that a single
+    # parse handles both the common and the task-specific options.
+    for action in parser._actions:
+ if not any(action.dest == a.dest for a in task_parser._actions):
+ task_parser._add_action(action)
+
+ task_parser.set_defaults(**vars(args))
+ task_args = task_parser.parse_args(extra_task_params)
+ return task_args
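
The two-stage parsing that build_args completes can be exercised roughly as
below. The import of get_parser assumes the common parser from the training
entry point in this patch; the argv is illustrative.

# Sketch of the two-stage parse: common options first, task options second.
from funasr.bin.train import get_parser  # assumed module path
from funasr.build_utils.build_args import build_args

parser = get_parser()
args, extra_task_params = parser.parse_known_args(
    ["--task_name", "asr", "--output_dir", "exp", "--encoder", "conformer"]
)
# "--encoder conformer" is unknown to the common parser, so it ends up in
# extra_task_params and is resolved by the ASR task parser inside build_args.
if extra_task_params:
    args = build_args(args, parser, extra_task_params)
print(args.encoder)  # -> "conformer"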
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
new file mode 100644
index 0000000..ddc827f
--- /dev/null
+++ b/funasr/build_utils/build_asr_model.py
@@ -0,0 +1,423 @@
+import logging
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+ DynamicConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolutionTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ multichannelfrontend=MultiChannelFrontend,
+ ),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ asr=ASRModel,
+ uniasr=UniASR,
+ paraformer=Paraformer,
+ paraformer_bert=ParaformerBert,
+ bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
+ mfcca=MFCCA,
+ timestamp_prediction=TimestampPredictor,
+ rnnt=TransducerModel,
+ rnnt_unified=UnifiedTransducerModel,
+ ),
+ default="asr",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ chunk_conformer=ConformerChunkEncoder,
+ ),
+ default="rnn",
+)
+encoder_choices2 = ClassChoices(
+ "encoder2",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ ),
+ default="rnn",
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
+ ),
+ default="rnn",
+)
+decoder_choices2 = ClassChoices(
+ "decoder2",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="rnn",
+)
+predictor_choices = ClassChoices(
+ name="predictor",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ cif_predictor_v3=CifPredictorV3,
+ ),
+ default="cif_predictor",
+ optional=True,
+)
+predictor_choices2 = ClassChoices(
+ name="predictor2",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ ),
+ default="cif_predictor",
+ optional=True,
+)
+stride_conv_choices = ClassChoices(
+ name="stride_conv",
+ classes=dict(
+ stride_conv1d=Conv1dSubsampling
+ ),
+ default="stride_conv1d",
+ optional=True,
+)
+rnnt_decoder_choices = ClassChoices(
+ name="rnnt_decoder",
+ classes=dict(
+ rnnt=RNNTDecoder,
+ ),
+ default="rnnt",
+ optional=True,
+)
+joint_network_choices = ClassChoices(
+ name="joint_network",
+ classes=dict(
+ joint_network=JointNetwork,
+ ),
+ default="joint_network",
+ optional=True,
+)
+
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ # --predictor and --predictor_conf
+ predictor_choices,
+ # --encoder2 and --encoder2_conf
+ encoder_choices2,
+ # --decoder2 and --decoder2_conf
+ decoder_choices2,
+ # --predictor2 and --predictor2_conf
+ predictor_choices2,
+ # --stride_conv and --stride_conv_conf
+ stride_conv_choices,
+ # --rnnt_decoder and --rnnt_decoder_conf
+ rnnt_decoder_choices,
+ # --joint_network and --joint_network_conf
+ joint_network_choices,
+]
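+
+# Each ClassChoices entry above is surfaced on the command line as a pair of
+# options, "--<name>" to pick the class and "--<name>_conf" for its keyword
+# arguments. For example (values are illustrative, not defaults):
+#
+#   --encoder conformer --encoder_conf '{"output_size": 256, "num_blocks": 12}'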
+
+
+def build_asr_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ token_list = None
+ vocab_size = None
+
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # ctc
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
+ )
+
+ if args.model in ["asr", "mfcca"]:
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+ elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ predictor=predictor,
+ **args.model_conf,
+ )
+ elif args.model == "uniasr":
+ # stride_conv
+ stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
+ stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
+ odim=input_size + encoder.output_size())
+ stride_conv_output_size = stride_conv.output_size()
+
+ # encoder2
+ encoder_class2 = encoder_choices2.get_class(args.encoder2)
+ encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
+
+ # decoder2
+ decoder_class2 = decoder_choices2.get_class(args.decoder2)
+ decoder2 = decoder_class2(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder2.output_size(),
+ **args.decoder2_conf,
+ )
+
+ # ctc2
+ ctc2 = CTC(
+ odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
+ )
+
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ # predictor2
+ predictor_class2 = predictor_choices2.get_class(args.predictor2)
+ predictor2 = predictor_class2(**args.predictor2_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ predictor=predictor,
+ ctc2=ctc2,
+ encoder2=encoder2,
+ decoder2=decoder2,
+ predictor2=predictor2,
+ stride_conv=stride_conv,
+ **args.model_conf,
+ )
+ elif args.model == "timestamp_prediction":
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ token_list=token_list,
+ **args.model_conf,
+ )
+ elif args.model == "rnnt" or args.model == "rnnt_unified":
+ # 5. Decoder
+ encoder_output_size = encoder.output_size()
+
+ rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+ decoder = rnnt_decoder_class(
+ vocab_size,
+ **args.rnnt_decoder_conf,
+ )
+ decoder_output_size = decoder.output_size
+
+ if getattr(args, "decoder", None) is not None:
+ att_decoder_class = decoder_choices.get_class(args.decoder)
+
+ att_decoder = att_decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+ else:
+ att_decoder = None
+ # joint network
+ joint_network = JointNetwork(
+ vocab_size,
+ encoder_output_size,
+ decoder_output_size,
+ **args.joint_network_conf,
+ )
+
+ model_class = model_choices.get_class(args.model)
+ # build model
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
+
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
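+
+# A rough construction sketch (illustrative; in practice `args` is the
+# namespace produced by build_args with task_name="asr", so it also carries
+# the matching *_conf dicts for every chosen class):
+#
+#   args.model = "paraformer"
+#   args.predictor = "cif_predictor"
+#   model = build_asr_model(args)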
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
new file mode 100644
index 0000000..c95c40d
--- /dev/null
+++ b/funasr/build_utils/build_dataloader.py
@@ -0,0 +1,15 @@
+from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
+
+
+def build_dataloader(args):
+ if args.dataset_type == "small":
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ elif args.dataset_type == "large":
+ train_iter_factory = LargeDataLoader(args, mode="train")
+ valid_iter_factory = LargeDataLoader(args, mode="valid")
+ else:
+ raise ValueError(f"Not supported dataset_type={args.dataset_type}")
+
+ return train_iter_factory, valid_iter_factory
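+
+# Usage sketch (illustrative): args.dataset_type selects the pipeline, and the
+# returned factories are consumed epoch by epoch by the Trainer, e.g.
+#
+#   train_iter_factory, valid_iter_factory = build_dataloader(args)
+#   iterator = train_iter_factory.build_iter(epoch)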
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
new file mode 100644
index 0000000..6406404
--- /dev/null
+++ b/funasr/build_utils/build_diar_model.py
@@ -0,0 +1,296 @@
+import logging
+
+import torch
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
+from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ wav_frontend_mel23=WavFrontendMel23,
+ ),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+label_aggregator_choices = ClassChoices(
+ "label_aggregator",
+ classes=dict(
+ label_aggregator=LabelAggregate
+ ),
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ sond=DiarSondModel,
+ eend_ola=DiarEENDOLAModel,
+ ),
+ default="sond",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ san=SelfAttentionEncoder,
+ fsmn=FsmnEncoder,
+ conv=ConvEncoder,
+ resnet34=ResNet34Diar,
+ resnet34_sp_l2reg=ResNet34SpL2RegDiar,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ ecapa_tdnn=ECAPA_TDNN,
+ eend_ola_transformer=EENDOLATransformerEncoder,
+ ),
+ default="resnet34",
+)
+speaker_encoder_choices = ClassChoices(
+ "speaker_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ san=SelfAttentionEncoder,
+ fsmn=FsmnEncoder,
+ conv=ConvEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ default=None,
+ optional=True
+)
+cd_scorer_choices = ClassChoices(
+ "cd_scorer",
+ classes=dict(
+ san=SelfAttentionEncoder,
+ ),
+ default=None,
+ optional=True,
+)
+ci_scorer_choices = ClassChoices(
+ "ci_scorer",
+ classes=dict(
+ dot=DotScorer,
+ cosine=CosScorer,
+ conv=ConvEncoder,
+ ),
+ type_check=torch.nn.Module,
+ default=None,
+ optional=True,
+)
+# decoder is used for output (e.g. post_net in SOND)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ rnn=RNNEncoder,
+ fsmn=FsmnEncoder,
+ ),
+ type_check=torch.nn.Module,
+ default="fsmn",
+)
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+ "encoder_decoder_attractor",
+ classes=dict(
+ eda=EncoderDecoderAttractor,
+ ),
+ type_check=torch.nn.Module,
+ default="eda",
+)
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --label_aggregator and --label_aggregator_conf
+ label_aggregator_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --speaker_encoder and --speaker_encoder_conf
+ speaker_encoder_choices,
+ # --cd_scorer and --cd_scorer_conf
+ cd_scorer_choices,
+ # --ci_scorer and --ci_scorer_conf
+ ci_scorer_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
+ encoder_decoder_attractor_choices,
+]
+
+
+def build_diar_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ token_list = None
+ vocab_size = None
+
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ if args.model_name == "sond":
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # speaker encoder
+ if getattr(args, "speaker_encoder", None) is not None:
+ speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
+ speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
+ else:
+ speaker_encoder = None
+
+ # ci scorer
+ if getattr(args, "ci_scorer", None) is not None:
+ ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
+ ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
+ else:
+ ci_scorer = None
+
+ # cd scorer
+ if getattr(args, "cd_scorer", None) is not None:
+ cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
+ cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
+ else:
+ cd_scorer = None
+
+ # decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # label aggregator
+ if getattr(args, "label_aggregator", None) is not None:
+ label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
+ label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
+ else:
+ label_aggregator = None
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ label_aggregator=label_aggregator,
+ encoder=encoder,
+ speaker_encoder=speaker_encoder,
+ ci_scorer=ci_scorer,
+ cd_scorer=cd_scorer,
+ decoder=decoder,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ elif args.model_name == "eend_ola":
+ # encoder-decoder attractor
+ encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+ encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
+
+ # build model
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ encoder_decoder_attractor=encoder_decoder_attractor,
+ **args.model_conf,
+ )
+
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
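+
+# Note: the branch above is selected by args.model_name ("sond" or "eend_ola"),
+# while the concrete class is still resolved from model_choices via args.model,
+# so the two options are expected to agree (e.g. model_name="sond", model="sond").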
diff --git a/funasr/build_utils/build_distributed.py b/funasr/build_utils/build_distributed.py
new file mode 100644
index 0000000..b64b4c0
--- /dev/null
+++ b/funasr/build_utils/build_distributed.py
@@ -0,0 +1,38 @@
+import logging
+import os
+
+import torch
+
+from funasr.train.distributed_utils import DistributedOption
+from funasr.utils.build_dataclass import build_dataclass
+
+
+def build_distributed(args):
+ distributed_option = build_dataclass(DistributedOption, args)
+ if args.use_pai:
+ distributed_option.init_options_pai()
+ distributed_option.init_torch_distributed_pai(args)
+ elif not args.simple_ddp:
+ distributed_option.init_torch_distributed(args)
+ elif args.distributed and args.simple_ddp:
+ distributed_option.init_torch_distributed_pai(args)
+ args.ngpu = torch.distributed.get_world_size()
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ logging.basicConfig(
+ level="INFO",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ else:
+ logging.basicConfig(
+ level="ERROR",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+ distributed_option.dist_rank,
+ distributed_option.local_rank))
+ return distributed_option
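+
+# Note: use_pai takes the PAI/OSS initialization path, simple_ddp reads the
+# world size back from torch.distributed, and the plain path falls through to
+# init_torch_distributed(); INFO-level logging is kept on rank 0 only.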
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
new file mode 100644
index 0000000..8f4a958
--- /dev/null
+++ b/funasr/build_utils/build_lm_model.py
@@ -0,0 +1,57 @@
+import logging
+
+from funasr.train.abs_model import AbsLM
+from funasr.train.abs_model import LanguageModel
+from funasr.models.seq_rnn_lm import SequentialRNNLM
+from funasr.models.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+ "lm",
+ classes=dict(
+ seq_rnn=SequentialRNNLM,
+ transformer=TransformerLM,
+ ),
+ type_check=AbsLM,
+ default="seq_rnn",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ lm=LanguageModel,
+ ),
+ default="lm",
+)
+
+class_choices_list = [
+ # --lm and --lm_conf
+ lm_choices,
+ # --model and --model_conf
+ model_choices
+]
+
+
+def build_lm_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ token_list = None
+ vocab_size = None
+
+ # lm
+ lm_class = lm_choices.get_class(args.lm)
+ lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
new file mode 100644
index 0000000..13a6faa
--- /dev/null
+++ b/funasr/build_utils/build_model.py
@@ -0,0 +1,25 @@
+from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_pretrain_model import build_pretrain_model
+from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_vad_model import build_vad_model
+from funasr.build_utils.build_diar_model import build_diar_model
+
+
+def build_model(args):
+ if args.task_name == "asr":
+ model = build_asr_model(args)
+ elif args.task_name == "pretrain":
+ model = build_pretrain_model(args)
+ elif args.task_name == "lm":
+ model = build_lm_model(args)
+ elif args.task_name == "punc":
+ model = build_punc_model(args)
+ elif args.task_name == "vad":
+ model = build_vad_model(args)
+ elif args.task_name == "diar":
+ model = build_diar_model(args)
+ else:
+ raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+ return model
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
new file mode 100644
index 0000000..bd0b73d
--- /dev/null
+++ b/funasr/build_utils/build_optimizer.py
@@ -0,0 +1,28 @@
+import torch
+
+from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
+
+
+def build_optimizer(args, model):
+ optim_classes = dict(
+ adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
+ adamw=torch.optim.AdamW,
+ sgd=SGD,
+ adadelta=torch.optim.Adadelta,
+ adagrad=torch.optim.Adagrad,
+ adamax=torch.optim.Adamax,
+ asgd=torch.optim.ASGD,
+ lbfgs=torch.optim.LBFGS,
+ rmsprop=torch.optim.RMSprop,
+ rprop=torch.optim.Rprop,
+ )
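+
+ # Example (illustrative values, not defaults):
+ # args.optim = "adamw", args.optim_conf = {"lr": 1e-3, "weight_decay": 0.01}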
+
+ optim_class = optim_classes.get(args.optim)
+ if optim_class is None:
+ raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+ optimizer = optim_class(model.parameters(), **args.optim_conf)
+
+ optimizers = [optimizer]
+ return optimizers
\ No newline at end of file
diff --git a/funasr/build_utils/build_pretrain_model.py b/funasr/build_utils/build_pretrain_model.py
new file mode 100644
index 0000000..629937f
--- /dev/null
+++ b/funasr/build_utils/build_pretrain_model.py
@@ -0,0 +1,107 @@
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(specaug=SpecAug),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ data2vec=Data2VecPretrainModel,
+ ),
+ default="data2vec",
+)
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+]
+
+
+def build_pretrain_model(args):
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(
+ input_size=input_size,
+ **args.encoder_conf,
+ )
+
+ if args.model == "data2vec":
+ model_class = model_choices.get_class("data2vec")
+ model = model_class(
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ )
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
new file mode 100644
index 0000000..62ccaf2
--- /dev/null
+++ b/funasr/build_utils/build_punc_model.py
@@ -0,0 +1,68 @@
+import logging
+
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_model import PunctuationModel
+from funasr.train.class_choices import ClassChoices
+
+punc_choices = ClassChoices(
+ "punctuation",
+ classes=dict(
+ target_delay=TargetDelayTransformer,
+ vad_realtime=VadRealtimeTransformer
+ ),
+ default="target_delay",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ punc=PunctuationModel,
+ ),
+ default="punc",
+)
+class_choices_list = [
+ # --punc and --punc_conf
+ punc_choices,
+ # --model and --model_conf
+ model_choices
+]
+
+
+def build_punc_model(args):
+ # token_list and punc list
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = token_list.copy()
+ if isinstance(args.punc_list, str):
+ with open(args.punc_list, encoding="utf-8") as f2:
+ pairs = [line.rstrip().split(":") for line in f2]
+ punc_list = [pair[0] for pair in pairs]
+ punc_weight_list = [float(pair[1]) for pair in pairs]
+ args.punc_list = punc_list.copy()
+ elif isinstance(args.punc_list, list):
+ punc_list = args.punc_list.copy()
+ punc_weight_list = [1] * len(punc_list)
+ if isinstance(args.token_list, (tuple, list)):
+ token_list = args.token_list.copy()
+ else:
+ raise RuntimeError("token_list must be str or dict")
+
+ vocab_size = len(token_list)
+ punc_size = len(punc_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # punc
+ punc_class = punc_choices.get_class(args.punctuation)
+ punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
+
+ if "punc_weight" in args.model_conf:
+ args.model_conf.pop("punc_weight")
+ model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
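+
+# Note: when args.punc_list is a file, each line is parsed as "<punc>:<weight>"
+# (see the split(":") above); a line such as "，:1.0" would map the comma to
+# weight 1.0 (illustrative).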
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
new file mode 100644
index 0000000..4b9990e
--- /dev/null
+++ b/funasr/build_utils/build_scheduler.py
@@ -0,0 +1,44 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizers):
+ scheduler_classes = dict(
+ ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+ lambdalr=torch.optim.lr_scheduler.LambdaLR,
+ steplr=torch.optim.lr_scheduler.StepLR,
+ multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+ exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+ CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+ noamlr=NoamLR,
+ warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
+ cycliclr=torch.optim.lr_scheduler.CyclicLR,
+ onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+ CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+ )
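+
+ # Example (illustrative values, not defaults):
+ # args.scheduler = "warmuplr", args.scheduler_conf = {"warmup_steps": 25000}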
+
+ schedulers = []
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
new file mode 100644
index 0000000..aff99b5
--- /dev/null
+++ b/funasr/build_utils/build_trainer.py
@@ -0,0 +1,820 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Trainer module."""
+import argparse
+import dataclasses
+import logging
+import os
+import time
+from contextlib import contextmanager
+from dataclasses import is_dataclass
+from distutils.version import LooseVersion
+from io import BytesIO
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+import oss2
+import torch
+import torch.nn
+import torch.optim
+from typeguard import check_argument_types
+
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.main_funcs.average_nbest_models import average_nbest_models
+from funasr.models.base_model import FunASRModel
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
+from funasr.schedulers.abs_scheduler import AbsScheduler
+from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
+from funasr.torch_utils.add_gradient_noise import add_gradient_noise
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.recursive_op import recursive_average
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.train.distributed_utils import DistributedOption
+from funasr.train.reporter import Reporter
+from funasr.train.reporter import SubReporter
+from funasr.utils.build_dataclass import build_dataclass
+
+if torch.distributed.is_available():
+ from torch.distributed import ReduceOp
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+ from torch.cuda.amp import GradScaler
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+ GradScaler = None
+
+try:
+ import fairscale
+except ImportError:
+ fairscale = None
+
+
+@dataclasses.dataclass
+class TrainerOptions:
+ ngpu: int
+ resume: bool
+ use_amp: bool
+ train_dtype: str
+ grad_noise: bool
+ accum_grad: int
+ grad_clip: float
+ grad_clip_type: float
+ log_interval: Optional[int]
+ # no_forward_run: bool
+ use_tensorboard: bool
+ # use_wandb: bool
+ output_dir: Union[Path, str]
+ max_epoch: int
+ max_update: int
+ seed: int
+ # sharded_ddp: bool
+ patience: Optional[int]
+ keep_nbest_models: Union[int, List[int]]
+ nbest_averaging_interval: int
+ early_stopping_criterion: Sequence[str]
+ best_model_criterion: Sequence[Sequence[str]]
+ val_scheduler_criterion: Sequence[str]
+ unused_parameters: bool
+ # wandb_model_log_interval: int
+ use_pai: bool
+ oss_bucket: Union[oss2.Bucket, None]
+
+
+class Trainer:
+ """Trainer
+
+ """
+
+ def __init__(self,
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption):
+ self.trainer_options = self.build_options(args)
+ self.model = model
+ self.optimizers = optimizers
+ self.schedulers = schedulers
+ self.train_dataloader = train_dataloader
+ self.valid_dataloader = valid_dataloader
+ self.distributed_option = distributed_option
+
+ def build_options(self, args: argparse.Namespace) -> TrainerOptions:
+ """Build options consumed by train(), eval()"""
+ assert check_argument_types()
+ return build_dataclass(TrainerOptions, args)
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ """Reserved for future development of another Trainer"""
+ pass
+
+ def resume(self,
+ checkpoint: Union[str, Path],
+ model: torch.nn.Module,
+ reporter: Reporter,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ ngpu: int = 0,
+ ):
+ states = torch.load(
+ checkpoint,
+ map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+ )
+ model.load_state_dict(states["model"])
+ reporter.load_state_dict(states["reporter"])
+ for optimizer, state in zip(optimizers, states["optimizers"]):
+ optimizer.load_state_dict(state)
+ for scheduler, state in zip(schedulers, states["schedulers"]):
+ if scheduler is not None:
+ scheduler.load_state_dict(state)
+ if scaler is not None:
+ if states["scaler"] is None:
+ logging.warning("scaler state is not found")
+ else:
+ scaler.load_state_dict(states["scaler"])
+
+ logging.info(f"The training was resumed using {checkpoint}")
+
+ def run(self) -> None:
+ """Perform training. This method performs the main process of training."""
+ assert check_argument_types()
+ # NOTE(kamo): Don't check the type more strictly as far as trainer_options is concerned
+ model = self.model
+ optimizers = self.optimizers
+ schedulers = self.schedulers
+ train_dataloader = self.train_dataloader
+ valid_dataloader = self.valid_dataloader
+ trainer_options = self.trainer_options
+ distributed_option = self.distributed_option
+ assert is_dataclass(trainer_options), type(trainer_options)
+ assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
+
+ if isinstance(trainer_options.keep_nbest_models, int):
+ keep_nbest_models = [trainer_options.keep_nbest_models]
+ else:
+ if len(trainer_options.keep_nbest_models) == 0:
+ logging.warning("No keep_nbest_models is given. Change to [1]")
+ trainer_options.keep_nbest_models = [1]
+ keep_nbest_models = trainer_options.keep_nbest_models
+
+ output_dir = Path(trainer_options.output_dir)
+ reporter = Reporter()
+ if trainer_options.use_amp:
+ if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+ raise RuntimeError(
+ "Require torch>=1.6.0 for Automatic Mixed Precision"
+ )
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
+ else:
+ scaler = None
+
+ if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+ self.resume(
+ checkpoint=output_dir / "checkpoint.pb",
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ reporter=reporter,
+ scaler=scaler,
+ ngpu=trainer_options.ngpu,
+ )
+
+ start_epoch = reporter.get_epoch() + 1
+ if start_epoch == trainer_options.max_epoch + 1:
+ logging.warning(
+ f"The training has already reached at max_epoch: {start_epoch}"
+ )
+
+ if distributed_option.distributed:
+ dp_model = torch.nn.parallel.DistributedDataParallel(
+ model, find_unused_parameters=trainer_options.unused_parameters)
+ elif distributed_option.ngpu > 1:
+ dp_model = torch.nn.parallel.DataParallel(
+ model,
+ device_ids=list(range(distributed_option.ngpu)),
+ )
+ else:
+ # NOTE(kamo): DataParallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ dp_model = model
+
+ if trainer_options.use_tensorboard and (
+ not distributed_option.distributed or distributed_option.dist_rank == 0
+ ):
+ from torch.utils.tensorboard import SummaryWriter
+ if trainer_options.use_pai:
+ train_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/train")
+ )
+ valid_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/valid")
+ )
+ else:
+ train_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "train")
+ )
+ valid_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "valid")
+ )
+ else:
+ train_summary_writer = valid_summary_writer = None
+
+ start_time = time.perf_counter()
+ for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+ if iepoch != start_epoch:
+ logging.info(
+ "{}/{}epoch started. Estimated time to finish: {}".format(
+ iepoch,
+ trainer_options.max_epoch,
+ humanfriendly.format_timespan(
+ (time.perf_counter() - start_time)
+ / (iepoch - start_epoch)
+ * (trainer_options.max_epoch - iepoch + 1)
+ ),
+ )
+ )
+ else:
+ logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
+ set_all_random_seed(trainer_options.seed + iepoch)
+
+ reporter.set_epoch(iepoch)
+ # 1. Train and validation for one-epoch
+ with reporter.observe("train") as sub_reporter:
+ all_steps_are_invalid, max_update_stop = self.train_one_epoch(
+ model=dp_model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ iterator=train_dataloader.build_iter(iepoch),
+ reporter=sub_reporter,
+ scaler=scaler,
+ summary_writer=train_summary_writer,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ with reporter.observe("valid") as sub_reporter:
+ self.validate_one_epoch(
+ model=dp_model,
+ iterator=valid_dataloader.build_iter(iepoch),
+ reporter=sub_reporter,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ # 2. LR Scheduler step
+ for scheduler in schedulers:
+ if isinstance(scheduler, AbsValEpochStepScheduler):
+ scheduler.step(
+ reporter.get_value(*trainer_options.val_scheduler_criterion)
+ )
+ elif isinstance(scheduler, AbsEpochStepScheduler):
+ scheduler.step()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
+
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ # 3. Report the results
+ logging.info(reporter.log_message())
+ if train_summary_writer is not None:
+ reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
+ reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
+
+ # save tensorboard on oss
+ if trainer_options.use_pai and train_summary_writer is not None:
+ def write_tensorboard_summary(summary_writer_path, oss_bucket):
+ file_list = []
+ for root, dirs, files in os.walk(summary_writer_path, topdown=False):
+ for name in files:
+ file_full_path = os.path.join(root, name)
+ file_list.append(file_full_path)
+
+ for file_full_path in file_list:
+ with open(file_full_path, "rb") as f:
+ oss_bucket.put_object(file_full_path, f)
+
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
+ trainer_options.oss_bucket)
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
+ trainer_options.oss_bucket)
+
+ # 4. Save/Update the checkpoint
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+ "ema_model": model.encoder.ema.model.state_dict()
+ if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
+ },
+ buffer,
+ )
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
+ buffer.getvalue())
+ else:
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+ },
+ output_dir / "checkpoint.pb",
+ )
+
+ # 5. Save and log the model and update the link to the best model
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(model.state_dict(), buffer)
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), buffer.getvalue())
+ else:
+ torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
+
+ # Creates a sym link latest.pb -> {iepoch}epoch.pb
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, "latest.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / "latest.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+
+ _improved = []
+ for _phase, k, _mode in trainer_options.best_model_criterion:
+ # e.g. _phase, k, _mode = "train", "loss", "min"
+ if reporter.has(_phase, k):
+ best_epoch = reporter.get_best_epoch(_phase, k, _mode)
+ # Creates sym links if it's the best result
+ if best_epoch == iepoch:
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / f"{_phase}.{k}.best.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+ _improved.append(f"{_phase}.{k}")
+ if len(_improved) == 0:
+ logging.info("There are no improvements in this epoch")
+ else:
+ logging.info(
+ "The best model has been updated: " + ", ".join(_improved)
+ )
+
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
+
+ # 6. Remove the model files excluding n-best epoch and latest epoch
+ _removed = []
+ # Get the union set of the n-best among multiple criterion
+ nbests = set().union(
+ *[
+ set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
+ for ph, k, m in trainer_options.best_model_criterion
+ if reporter.has(ph, k)
+ ]
+ )
+
+ # Generated n-best averaged model
+ if (
+ trainer_options.nbest_averaging_interval > 0
+ and iepoch % trainer_options.nbest_averaging_interval == 0
+ ):
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ suffix=f"till{iepoch}epoch",
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ for e in range(1, iepoch):
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
+ if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
+ trainer_options.oss_bucket.delete_object(p)
+ _removed.append(str(p))
+ else:
+ p = output_dir / f"{e}epoch.pb"
+ if p.exists() and e not in nbests:
+ p.unlink()
+ _removed.append(str(p))
+ if len(_removed) != 0:
+ logging.info("The model files were removed: " + ", ".join(_removed))
+
+ # 7. If no update happened at all, stop the training
+ if all_steps_are_invalid:
+ logging.warning(
+ f"The gradients at all steps are invalid in this epoch. "
+ f"Something seems wrong. This training was stopped at {iepoch}epoch"
+ )
+ break
+
+ if max_update_stop:
+ logging.info(
+ f"Stopping training due to "
+ f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
+ )
+ break
+
+ # 8. Check early stopping
+ if trainer_options.patience is not None:
+ if reporter.check_early_stopping(
+ trainer_options.patience, *trainer_options.early_stopping_criterion
+ ):
+ break
+
+ else:
+ logging.info(
+ f"The training was finished at {trainer_options.max_epoch} epochs "
+ )
+
+ # Generated n-best averaged model
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ def train_one_epoch(
+ self,
+ model: torch.nn.Module,
+ iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ reporter: SubReporter,
+ summary_writer,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> Tuple[bool, bool]:
+ assert check_argument_types()
+
+ grad_noise = options.grad_noise
+ accum_grad = options.accum_grad
+ grad_clip = options.grad_clip
+ grad_clip_type = options.grad_clip_type
+ log_interval = options.log_interval
+ # no_forward_run = options.no_forward_run
+ ngpu = options.ngpu
+ # use_wandb = options.use_wandb
+ distributed = distributed_option.distributed
+
+ if log_interval is None:
+ try:
+ log_interval = max(len(iterator) // 20, 10)
+ except TypeError:
+ log_interval = 100
+
+ model.train()
+ all_steps_are_invalid = True
+ max_update_stop = False
+ # [For distributed] Because iteration counts are not always equal across
+ # processes, send a stop-flag to the other processes when the iterator is exhausted
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+
+ start_time = time.perf_counter()
+ for iiter, (_, batch) in enumerate(
+ reporter.measure_iter_time(iterator, "iter_time"), 1
+ ):
+ assert isinstance(batch, dict), type(batch)
+
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
+
+ with autocast(scaler is not None):
+ with reporter.measure_time("forward_time"):
+ retval = model(**batch)
+
+ # Note(kamo):
+ # Supporting two patterns for the returned value from the model
+ # a. dict type
+ if isinstance(retval, dict):
+ loss = retval["loss"]
+ stats = retval["stats"]
+ weight = retval["weight"]
+ optim_idx = retval.get("optim_idx")
+ if optim_idx is not None and not isinstance(optim_idx, int):
+ if not isinstance(optim_idx, torch.Tensor):
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {type(optim_idx)}"
+ )
+ if optim_idx.dim() >= 2:
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {optim_idx.dim()}dim tensor"
+ )
+ if optim_idx.dim() == 1:
+ for v in optim_idx:
+ if v != optim_idx[0]:
+ raise RuntimeError(
+ "optim_idx must be 1dim tensor "
+ "having same values for all entries"
+ )
+ optim_idx = optim_idx[0].item()
+ else:
+ optim_idx = optim_idx.item()
+
+ # b. tuple or list type
+ else:
+ loss, stats, weight = retval
+ optim_idx = None
+
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ # Now weight is summation over all workers
+ loss /= weight
+ if distributed:
+ # NOTE(kamo): Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= torch.distributed.get_world_size()
+
+ loss /= accum_grad
+
+ reporter.register(stats, weight)
+
+ with reporter.measure_time("backward_time"):
+ if scaler is not None:
+ # Scales loss. Calls backward() on scaled loss
+ # to create scaled gradients.
+ # Backward passes under autocast are not recommended.
+ # Backward ops run in the same dtype autocast chose
+ # for corresponding forward ops.
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if iiter % accum_grad == 0:
+ if scaler is not None:
+ # Unscales the gradients of optimizer's assigned params in-place
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.unscale_(optimizer)
+
+ # gradient noise injection
+ if grad_noise:
+ add_gradient_noise(
+ model,
+ reporter.get_total_count(),
+ duration=100,
+ eta=1.0,
+ scale_factor=0.55,
+ )
+
+ # compute the gradient norm to check if it is normal or not
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ model.parameters(),
+ max_norm=grad_clip,
+ norm_type=grad_clip_type,
+ )
+ # PyTorch<=1.4, clip_grad_norm_ returns float value
+ if not isinstance(grad_norm, torch.Tensor):
+ grad_norm = torch.tensor(grad_norm)
+
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+
+ # Must invoke scaler.update() if unscale_() is used in the iteration
+ # to avoid the following error:
+ # RuntimeError: unscale_() has already been called
+ # on this optimizer since the last update().
+ # Note that if the gradient has inf/nan values,
+ # scaler.step skips optimizer.step().
+ if scaler is not None:
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ all_steps_are_invalid = False
+ with reporter.measure_time("optim_step_time"):
+ for iopt, (optimizer, scheduler) in enumerate(
+ zip(optimizers, schedulers)
+ ):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ if scaler is not None:
+ # scaler.step() first unscales the gradients of
+ # the optimizer's assigned params.
+ scaler.step(optimizer)
+ # Updates the scale for next iteration.
+ scaler.update()
+ else:
+ optimizer.step()
+ if isinstance(scheduler, AbsBatchStepScheduler):
+ scheduler.step()
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ optimizer.zero_grad()
+
+ # Register lr and train/load time[sec/step],
+ # where step refers to accum_grad * mini-batch
+ reporter.register(
+ dict(
+ {
+ f"optim{i}_lr{j}": pg["lr"]
+ for i, optimizer in enumerate(optimizers)
+ for j, pg in enumerate(optimizer.param_groups)
+ if "lr" in pg
+ },
+ train_time=time.perf_counter() - start_time,
+ ),
+ )
+ start_time = time.perf_counter()
+
+ # update num_updates
+ if distributed:
+ if hasattr(model.module, "num_updates"):
+ model.module.set_num_updates(model.module.get_num_updates() + 1)
+ options.num_updates = model.module.get_num_updates()
+ if model.module.get_num_updates() >= options.max_update:
+ max_update_stop = True
+ else:
+ if hasattr(model, "num_updates"):
+ model.set_num_updates(model.get_num_updates() + 1)
+ options.num_updates = model.get_num_updates()
+ if model.get_num_updates() >= options.max_update:
+ max_update_stop = True
+
+ # NOTE(kamo): Call log_message() after next()
+ reporter.next()
+ if iiter % log_interval == 0:
+ num_updates = options.num_updates if hasattr(options, "num_updates") else None
+ logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
+ if summary_writer is not None:
+ reporter.tensorboard_add_scalar(summary_writer, -log_interval)
+ # if use_wandb:
+ # reporter.wandb_log()
+
+ if max_update_stop:
+ break
+
+ else:
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ return all_steps_are_invalid, max_update_stop
+
+ @torch.no_grad()
+ def validate_one_epoch(
+ self,
+ model: torch.nn.Module,
+ iterator: Iterable[Dict[str, torch.Tensor]],
+ reporter: SubReporter,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> None:
+ assert check_argument_types()
+ ngpu = options.ngpu
+ # no_forward_run = options.no_forward_run
+ distributed = distributed_option.distributed
+
+ model.eval()
+
+        # [For distributed] Because iteration counts are not always equal across
+        # processes, send a stop flag to the others once this iterator is exhausted
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+ for (_, batch) in iterator:
+ assert isinstance(batch, dict), type(batch)
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ # if no_forward_run:
+ # continue
+
+ retval = model(**batch)
+ if isinstance(retval, dict):
+ stats = retval["stats"]
+ weight = retval["weight"]
+ else:
+ _, stats, weight = retval
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for stats.
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ reporter.register(stats, weight)
+ reporter.next()
+
+ else:
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption
+):
+ trainer = Trainer(
+ args=args,
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ distributed_option=distributed_option
+ )
+ return trainer
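+
+# A minimal usage sketch (hypothetical objects: `args` is the parsed training
+# namespace, the model/optimizers/factories come from the other build_*
+# helpers, and the run() entry point is assumed):
+#   >>> trainer = build_trainer(args, model, [optimizer], [scheduler],
+#   ...                         train_iter_factory, valid_iter_factory,
+#   ...                         distributed_option)
+#   >>> trainer.run()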
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
new file mode 100644
index 0000000..76eb09b
--- /dev/null
+++ b/funasr/build_utils/build_vad_model.py
@@ -0,0 +1,77 @@
+import torch
+
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ wav_frontend_online=WavFrontendOnline,
+ ),
+ default="default",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ fsmn=FSMN,
+ ),
+ type_check=torch.nn.Module,
+ default="fsmn",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ e2evad=E2EVadModel,
+ ),
+ default="e2evad",
+)
+
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+]
+
+
+def build_vad_model(args):
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(**args.encoder_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
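+
+# A minimal usage sketch (hypothetical config values; any argparse-style
+# namespace exposing the attributes read above works, and fsmn_conf /
+# post_conf stand in for the FSMN and VAD post-processing parameters):
+#   >>> from argparse import Namespace
+#   >>> args = Namespace(input_size=400, encoder="fsmn", encoder_conf=fsmn_conf,
+#   ...                  model="e2evad", vad_post_conf=post_conf, init=None)
+#   >>> vad_model = build_vad_model(args)  # frontend skipped: input_size given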
diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py
deleted file mode 100644
index 860492c..0000000
--- a/funasr/datasets/iterable_dataset_modelscope.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-"""Iterable dataset module."""
-import copy
-from io import StringIO
-from pathlib import Path
-from typing import Callable, Collection, Dict, Iterator, Tuple, Union
-
-import kaldiio
-import numpy as np
-import soundfile
-import torch
-from funasr.datasets.dataset import ESPnetDataset
-from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
-
-from funasr.utils import wav_utils
-
-
-def load_kaldi(input):
- retval = kaldiio.load_mat(input)
- if isinstance(retval, tuple):
- assert len(retval) == 2, len(retval)
- if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
- # sound scp case
- rate, array = retval
- elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
- # Extended ark format case
- array, rate = retval
- else:
- raise RuntimeError(
- f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
-
- # Multichannel wave fie
- # array: (NSample, Channel) or (Nsample)
-
- else:
- # Normal ark case
- assert isinstance(retval, np.ndarray), type(retval)
- array = retval
- return array
-
-
-DATA_TYPES = {
- 'sound':
- lambda x: soundfile.read(x)[0],
- 'kaldi_ark':
- load_kaldi,
- 'npy':
- np.load,
- 'text_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
- 'csv_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
- 'text_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
- ),
- 'csv_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
- ),
- 'text':
- lambda x: x,
-}
-
-
-class IterableESPnetDatasetModelScope(IterableDataset):
- """Pytorch Dataset class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- path_list = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = path_list, _type
- # for path, name, _type in path_name_type_list:
- for path in path_list:
- self.path_name_type_list.append((path, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
- torch.set_printoptions(profile='default')
- count = len(self.path_name_type_list)
- for idx in range(count):
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
- value = self.path_name_type_list[idx][0]['file']
- uid = self.path_name_type_list[idx][0]['key']
- # name: speech
- name = self.path_name_type_list[idx][1]
- _type = self.path_name_type_list[idx][2]
- func = DATA_TYPES[_type]
- array = func(value)
-
- # 2.b. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- # array: [ 1.25122070e-03 ... ]
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
-
- if count == 0:
- raise RuntimeError('No iteration')
-
-
-class IterableESPnetBytesModelScope(IterableDataset):
- """Pytorch audio bytes class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- audio_data = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = audio_data, _type
- self.path_name_type_list.append((audio_data, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- if float_dtype == 'float32':
- self.np_dtype = np.float32
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-
- torch.set_printoptions(profile='default')
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- value = self.path_name_type_list[0][0]
- uid = 'pcm_data'
- # name: speech
- name = self.path_name_type_list[0][1]
- _type = self.path_name_type_list[0][2]
- func = DATA_TYPES[_type]
- # array: [ 1.25122070e-03 ... ]
- # data[name] = np.frombuffer(value, dtype=self.np_dtype)
-
- # 2.b. byte(PCM16) to float32
- middle_data = np.frombuffer(value, dtype=np.int16)
- middle_data = np.asarray(middle_data)
- if middle_data.dtype.kind not in 'iu':
- raise TypeError("'middle_data' must be an array of integers")
- dtype = np.dtype('float32')
- if dtype.kind != 'f':
- raise TypeError("'dtype' must be a floating point type")
-
- i = np.iinfo(middle_data.dtype)
- abs_max = 2**(i.bits - 1)
- offset = i.min + abs_max
- array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
- dtype=self.np_dtype)
-
- # 2.c. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.')
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 156f608..318ae0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -64,28 +64,26 @@
return self.sp.DecodePieces(list(tokens))
-class ArkDataLoader(AbsIterFactory):
- def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
- bpemodel_file=None, mode="train"):
- symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
- if seg_dict_file is not None:
- seg_dict = load_seg_dict(seg_dict_file)
- else:
- seg_dict = None
- if punc_dict_file is not None:
- punc_dict = read_symbol_table(punc_dict_file)
- else:
- punc_dict = None
- self.dataset_conf = dataset_conf
- self.frontend_conf = frontend_conf
+class LargeDataLoader(AbsIterFactory):
+ def __init__(self, args, mode="train"):
+ symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
+ if hasattr(args, "token_list") and args.token_list is not None:
+ symbol_table = read_symbol_table(args.token_list)
+ if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
+ seg_dict = load_seg_dict(args.seg_dict_file)
+ if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
+ punc_dict = read_symbol_table(args.punc_dict_file)
+ if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
+ bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+ self.dataset_conf = args.dataset_conf
+ self.frontend_conf = args.frontend_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
- if bpemodel_file is not None:
- bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
- else:
- bpe_tokenizer = None
+ data_list = args.train_data_file if mode == "train" else args.valid_data_file
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
- self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
+ self.dataset_conf, self.frontend_conf,
+ speed_perturb=args.speed_perturb if mode == "train" else None,
+ mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 8c224d8..5df61fd 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
+import logging
import os
import random
-import numpy
from functools import partial
import torch
-import torchaudio
import torch.distributed as dist
+import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
-from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -28,7 +28,8 @@
class AudioDataset(IterableDataset):
- def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
+ def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
+ mode="train"):
self.scp_lists = scp_lists
self.data_names = data_names
self.data_types = data_types
@@ -40,6 +41,9 @@
self.world_size = 1
self.worker_id = 0
self.num_workers = 1
+ self.speed_perturb = speed_perturb
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
def set_epoch(self, epoch):
self.epoch = epoch
@@ -124,9 +128,15 @@
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend_conf["fs"])(waveform)
- sampling_rate = self.frontend_conf["fs"]
+ sampling_rate = self.frontend_conf["fs"]
waveform = waveform.numpy()
mat = waveform[0]
+ if self.speed_perturb is not None:
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ mat, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
+ mat = mat.view(-1).numpy()
sample_dict[data_name] = mat
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
@@ -168,6 +178,7 @@
bpe_tokenizer,
conf,
frontend_conf,
+ speed_perturb=None,
mode="train",
batch_mode="padding"):
scp_lists = read_lists(data_list_file)
@@ -196,7 +207,8 @@
data_names,
data_types,
frontend_conf=frontend_conf,
- shuffle=shuffle,
+ shuffle=shuffle,
+ speed_perturb=speed_perturb,
mode=mode,
)
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index f0f0c66..cf7d255 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -48,7 +48,7 @@
vad = -2
if bpe_tokenizer is not None:
- text = bpe_tokenizer.text2tokens("".join(text))
+ text = bpe_tokenizer.text2tokens(text)
if seg_dict is not None:
assert isinstance(seg_dict, dict)
diff --git a/funasr/datasets/small_datasets/collate_fn.py b/funasr/datasets/small_datasets/collate_fn.py
new file mode 100644
index 0000000..573f581
--- /dev/null
+++ b/funasr/datasets/small_datasets/collate_fn.py
@@ -0,0 +1,93 @@
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.modules.nets_utils import pad_list
+
+
+class CommonCollateFn:
+ """Functor class of common_collate_fn()"""
+
+ def __init__(
+ self,
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+ max_sample_size=None
+ ):
+ assert check_argument_types()
+ self.float_pad_value = float_pad_value
+ self.int_pad_value = int_pad_value
+ self.not_sequence = set(not_sequence)
+ self.max_sample_size = max_sample_size
+
+ def __repr__(self):
+ return (
+ f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+ f"int_pad_value={self.float_pad_value})"
+ )
+
+ def __call__(
+ self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+ ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ return common_collate_fn(
+ data,
+ float_pad_value=self.float_pad_value,
+ int_pad_value=self.int_pad_value,
+ not_sequence=self.not_sequence,
+ )
+
+
+def common_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ """Concatenate ndarray-list to an array and convert to torch.Tensor.
+ """
+ assert check_argument_types()
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ if data[0][key].dtype.kind == "i":
+ pad_value = int_pad_value
+ else:
+ pad_value = float_pad_value
+
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ tensor = pad_list(tensor_list, pad_value)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ assert check_return_type(output)
+ return output
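+
+# A sketch of the contract (hypothetical utterance ids and shapes): two
+# "speech" arrays of lengths 3 and 2 are padded into one (2, 3) tensor and a
+# "speech_lengths" tensor [3, 2] is added for the sequence key:
+#   >>> fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+#   >>> ids, batch = fn([("u1", {"speech": np.zeros(3, np.float32)}),
+#   ...                  ("u2", {"speech": np.zeros(2, np.float32)})])
+#   >>> batch["speech"].shape, batch["speech_lengths"].tolist()
+#   (torch.Size([2, 3]), [3, 2])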
+
+def crop_to_max_size(feature, target_size):
+ size = len(feature)
+ diff = size - target_size
+ if diff <= 0:
+ return feature
+
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return feature[start:end]
\ No newline at end of file
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
new file mode 100644
index 0000000..123f109
--- /dev/null
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -0,0 +1,269 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import collections
+import copy
+import logging
+import numbers
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Mapping
+from typing import Union, List, Tuple
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.npy_scp import NpyScpReader
+from funasr.fileio.sound_scp import SoundScpReader
+
+
+class AdapterForSoundScpReader(collections.abc.Mapping):
+ def __init__(self, loader, dtype=None):
+ assert check_argument_types()
+ self.loader = loader
+ self.dtype = dtype
+ self.rate = None
+
+ def keys(self):
+ return self.loader.keys()
+
+ def __len__(self):
+ return len(self.loader)
+
+ def __iter__(self):
+ return iter(self.loader)
+
+ def __getitem__(self, key: str) -> np.ndarray:
+ retval = self.loader[key]
+
+ if isinstance(retval, tuple):
+ assert len(retval) == 2, len(retval)
+ if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
+ # sound scp case
+ rate, array = retval
+ elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
+ # Extended ark format case
+ array, rate = retval
+ else:
+ raise RuntimeError(
+ f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
+ )
+
+ if self.rate is not None and self.rate != rate:
+ raise RuntimeError(
+ f"Sampling rates are mismatched: {self.rate} != {rate}"
+ )
+ self.rate = rate
+                # Multichannel wave file
+                # array: (Nsample, Channel) or (Nsample,)
+ if self.dtype is not None:
+ array = array.astype(self.dtype)
+
+ else:
+ # Normal ark case
+ assert isinstance(retval, np.ndarray), type(retval)
+ array = retval
+ if self.dtype is not None:
+ array = array.astype(self.dtype)
+
+ assert isinstance(array, np.ndarray), type(array)
+ return array
+
+
+def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
+ # The file is as follows:
+ # utterance_id_A /some/where/a.wav
+ # utterance_id_B /some/where/a.flac
+
+    # NOTE(kamo): SoundScpReader doesn't support pipe-style input
+    # like Kaldi, e.g. "cat a.wav |".
+ # NOTE(kamo): The audio signal is normalized to [-1,1] range.
+ loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
+
+ # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
+ # but ndarray is desired, so Adapter class is inserted here
+ return AdapterForSoundScpReader(loader, float_dtype)
+
+
+def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
+ loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
+ return AdapterForSoundScpReader(loader, float_dtype)
+
+
+class ESPnetDataset(Dataset):
+ """
+ Pytorch Dataset class for FunASR, modified from ESPnet
+ """
+
+ def __init__(
+ self,
+ path_name_type_list: Collection[Tuple[str, str, str]],
+ preprocess: Callable[
+ [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+ ] = None,
+ float_dtype: str = "float32",
+ int_dtype: str = "long",
+ dest_sample_rate: int = 16000,
+ speed_perturb: Union[list, tuple] = None,
+ mode: str = "train",
+ ):
+ assert check_argument_types()
+ if len(path_name_type_list) == 0:
+ raise ValueError(
+ '1 or more elements are required for "path_name_type_list"'
+ )
+
+ path_name_type_list = copy.deepcopy(path_name_type_list)
+ self.preprocess = preprocess
+
+ self.float_dtype = float_dtype
+ self.int_dtype = int_dtype
+ self.dest_sample_rate = dest_sample_rate
+ self.speed_perturb = speed_perturb
+ self.mode = mode
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
+
+ self.loader_dict = {}
+ self.debug_info = {}
+ for path, name, _type in path_name_type_list:
+ if name in self.loader_dict:
+ raise RuntimeError(f'"{name}" is duplicated for data-key')
+
+ loader = self._build_loader(path, _type)
+ self.loader_dict[name] = loader
+ self.debug_info[name] = path, _type
+ if len(self.loader_dict[name]) == 0:
+ raise RuntimeError(f"{path} has no samples")
+
+ def _build_loader(
+ self, path: str, loader_type: str
+ ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
+ """Helper function to instantiate Loader.
+
+ Args:
+ path: The file path
+            loader_type: one of "sound", "kaldi_ark", "npy", "text", "text_int"
+ """
+ if loader_type == "sound":
+ speed_perturb = self.speed_perturb if self.mode == "train" else None
+ loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
+ speed_perturb=speed_perturb)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "kaldi_ark":
+ loader = kaldiio.load_scp(path)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "npy":
+ return NpyScpReader(path)
+ elif loader_type == "text":
+ text_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
+ else:
+ k, v = sps
+ if k in text_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_loader[k] = v
+ return text_loader
+ elif loader_type == "text_int":
+ text_int_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
+ else:
+ k, v = sps
+ if k in text_int_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_int_loader[k] = [int(i) for i in v.split()]
+ return text_int_loader
+ else:
+ raise RuntimeError(f"Not supported: loader_type={loader_type}")
+
+ def has_name(self, name) -> bool:
+ return name in self.loader_dict
+
+ def names(self) -> Tuple[str, ...]:
+ return tuple(self.loader_dict)
+
+ def __iter__(self):
+ return iter(next(iter(self.loader_dict.values())))
+
+ def __repr__(self):
+ _mes = self.__class__.__name__
+ _mes += "("
+ for name, (path, _type) in self.debug_info.items():
+ _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
+ _mes += f"\n preprocess: {self.preprocess})"
+ return _mes
+
+ def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
+ assert check_argument_types()
+
+ # Change integer-id to string-id
+ if isinstance(uid, int):
+ d = next(iter(self.loader_dict.values()))
+ uid = list(d)[uid]
+
+ data = {}
+        # 1. Load data from each loader
+ for name, loader in self.loader_dict.items():
+ try:
+ value = loader[uid]
+ if isinstance(value, (list, tuple)):
+ value = np.array(value)
+ if not isinstance(
+ value, (np.ndarray, torch.Tensor, str, numbers.Number)
+ ):
+ raise TypeError(
+ f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
+ )
+ except Exception:
+ path, _type = self.debug_info[name]
+ logging.error(
+ f"Error happened with path={path}, type={_type}, id={uid}"
+ )
+ raise
+
+ # torch.Tensor is converted to ndarray
+ if isinstance(value, torch.Tensor):
+ value = value.numpy()
+ elif isinstance(value, numbers.Number):
+ value = np.array([value])
+ data[name] = value
+
+ # 2. [Option] Apply preprocessing
+ # e.g. funasr.train.preprocessor:CommonPreprocessor
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+
+ # 3. Force data-precision
+ for name in data:
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == "i":
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ data[name] = value
+
+ retval = uid, data
+ assert check_return_type(retval)
+ return retval
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
new file mode 100644
index 0000000..8ee8bdc
--- /dev/null
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -0,0 +1,147 @@
+from typing import Iterator
+from typing import List
+from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from typeguard import check_argument_types
+
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class LengthBatchSampler(AbsSampler):
+ def __init__(
+ self,
+ batch_bins: int,
+ shape_files: Union[Tuple[str, ...], List[str], Dict],
+ min_batch_size: int = 1,
+ sort_in_batch: str = "descending",
+ sort_batch: str = "ascending",
+ drop_last: bool = False,
+ padding: bool = True,
+ ):
+ assert check_argument_types()
+ assert batch_bins > 0
+ if sort_batch != "ascending" and sort_batch != "descending":
+ raise ValueError(
+ f"sort_batch must be ascending or descending: {sort_batch}"
+ )
+ if sort_in_batch != "descending" and sort_in_batch != "ascending":
+ raise ValueError(
+ f"sort_in_batch must be ascending or descending: {sort_in_batch}"
+ )
+
+ self.batch_bins = batch_bins
+ self.shape_files = shape_files
+ self.sort_in_batch = sort_in_batch
+ self.sort_batch = sort_batch
+ self.drop_last = drop_last
+
+ # utt2shape: (Length, ...)
+ # uttA 100,...
+ # uttB 201,...
+ if isinstance(shape_files, dict):
+ utt2shapes = [shape_files]
+ else:
+ utt2shapes = [
+ load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+ ]
+
+ first_utt2shape = utt2shapes[0]
+ for s, d in zip(shape_files, utt2shapes):
+ if set(d) != set(first_utt2shape):
+ raise RuntimeError(
+ f"keys are mismatched between {s} != {shape_files[0]}"
+ )
+
+ # Sort samples in ascending order
+ # (shape order should be like (Length, Dim))
+ keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
+ if len(keys) == 0:
+ raise RuntimeError(f"0 lines found: {shape_files[0]}")
+
+ # Decide batch-sizes
+ batch_sizes = []
+ current_batch_keys = []
+ for key in keys:
+ current_batch_keys.append(key)
+ # shape: (Length, dim1, dim2, ...)
+ if padding:
+ # bins = bs x max_length
+ bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
+ else:
+ # bins = sum of lengths
+ bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
+
+ if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
+ batch_sizes.append(len(current_batch_keys))
+ current_batch_keys = []
+ else:
+ if len(current_batch_keys) != 0 and (
+ not self.drop_last or len(batch_sizes) == 0
+ ):
+ batch_sizes.append(len(current_batch_keys))
+
+ if len(batch_sizes) == 0:
+ # Maybe we can't reach here
+ raise RuntimeError("0 batches")
+
+ # If the last batch-size is smaller than minimum batch_size,
+ # the samples are redistributed to the other mini-batches
+ if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
+ for i in range(batch_sizes.pop(-1)):
+ batch_sizes[-(i % len(batch_sizes)) - 1] += 1
+
+ if not self.drop_last:
+ # Bug check
+ assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
+
+ # Set mini-batch
+ self.batch_list = []
+ iter_bs = iter(batch_sizes)
+ bs = next(iter_bs)
+ minibatch_keys = []
+ for key in keys:
+ minibatch_keys.append(key)
+ if len(minibatch_keys) == bs:
+ if sort_in_batch == "descending":
+ minibatch_keys.reverse()
+ elif sort_in_batch == "ascending":
+                    # Keys are already sorted in ascending order
+ pass
+ else:
+ raise ValueError(
+ "sort_in_batch must be ascending"
+ f" or descending: {sort_in_batch}"
+ )
+ self.batch_list.append(tuple(minibatch_keys))
+ minibatch_keys = []
+ try:
+ bs = next(iter_bs)
+ except StopIteration:
+ break
+
+ if sort_batch == "ascending":
+ pass
+ elif sort_batch == "descending":
+ self.batch_list.reverse()
+ else:
+ raise ValueError(
+ f"sort_batch must be ascending or descending: {sort_batch}"
+ )
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}("
+ f"N-batch={len(self)}, "
+ f"batch_bins={self.batch_bins}, "
+ f"sort_in_batch={self.sort_in_batch}, "
+ f"sort_batch={self.sort_batch})"
+ )
+
+ def __len__(self):
+ return len(self.batch_list)
+
+ def __iter__(self) -> Iterator[Tuple[str, ...]]:
+ return iter(self.batch_list)
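+
+# A worked sketch (hypothetical utterances, one shape file, padding=True):
+# with lengths {a: 80, b: 100, c: 150, d: 400} and batch_bins=300, keys are
+# visited in ascending order and a batch closes once bs * current_len exceeds
+# batch_bins (the triggering key stays in the closed batch):
+#   a -> 1*80, +b -> 2*100, +c -> 3*150 = 450 > 300  => batch (c, b, a)
+#   d -> 1*400 > 300                                 => batch (d,)
+# i.e. batch_sizes == [3, 1] with sort_in_batch="descending".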
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
new file mode 100644
index 0000000..d80f48a
--- /dev/null
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -0,0 +1,875 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import scipy.signal
+import soundfile
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.token_id_converter import TokenIDConverter
+
+
+class AbsPreprocessor(ABC):
+ def __init__(self, train: bool):
+ self.train = train
+
+ @abstractmethod
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ raise NotImplementedError
+
+
+def forward_segment(text, dic):
+ word_list = []
+ i = 0
+ while i < len(text):
+ longest_word = text[i]
+ for j in range(i + 1, len(text) + 1):
+ word = text[i:j]
+ if word in dic:
+ if len(word) > len(longest_word):
+ longest_word = word
+ word_list.append(longest_word)
+ i += len(longest_word)
+ return word_list
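+
+# Forward maximum matching: at every position the longest dictionary entry
+# wins, and single characters outside the dictionary pass through. A classic
+# sketch (hypothetical seg_dict keys):
+#   >>> forward_segment("研究生命", {"研究", "研究生", "生命"})
+#   ['研究生', '命']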
+
+
+def seg_tokenize(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def seg_tokenize_wo_pattern(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def framing(
+ x,
+ frame_length: int = 512,
+ frame_shift: int = 256,
+ centered: bool = True,
+ padded: bool = True,
+):
+ if x.size == 0:
+ raise ValueError("Input array size is zero")
+ if frame_length < 1:
+ raise ValueError("frame_length must be a positive integer")
+ if frame_length > x.shape[-1]:
+ raise ValueError("frame_length is greater than input length")
+ if 0 >= frame_shift:
+ raise ValueError("frame_shift must be greater than 0")
+
+ if centered:
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
+ (frame_length // 2, frame_length // 2)
+ ]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ if padded:
+ # Pad to integer number of windowed segments
+ # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
+ # with integer nseg
+ nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ # Created strided array of data segments
+ if frame_length == 1 and frame_length == frame_shift:
+ result = x[..., None]
+ else:
+ shape = x.shape[:-1] + (
+ (x.shape[-1] - frame_length) // frame_shift + 1,
+ frame_length,
+ )
+ strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
+ result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return result
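+
+# Shape sketch: for x of length 16 with frame_length=4, frame_shift=2,
+# centered=False, padded=True, the strided result is (7, 4) frames since
+# (16 - 4) // 2 + 1 = 7; centered=True first pads frame_length // 2 = 2
+# samples per side, giving (20 - 4) // 2 + 1 = 9 frames.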
+
+
+def detect_non_silence(
+ x: np.ndarray,
+ threshold: float = 0.01,
+ frame_length: int = 1024,
+ frame_shift: int = 512,
+ window: str = "boxcar",
+) -> np.ndarray:
+ """Power based voice activity detection.
+
+ Args:
+ x: (Channel, Time)
+ >>> x = np.random.randn(1000)
+ >>> detect = detect_non_silence(x)
+ >>> assert x.shape == detect.shape
+        >>> assert detect.dtype == bool
+ """
+ if x.shape[-1] < frame_length:
+        return np.full(x.shape, fill_value=True, dtype=bool)
+
+ if x.dtype.kind == "i":
+ x = x.astype(np.float64)
+ # framed_w: (C, T, F)
+ framed_w = framing(
+ x,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ centered=False,
+ padded=True,
+ )
+ framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
+ # power: (C, T)
+ power = (framed_w ** 2).mean(axis=-1)
+ # mean_power: (C, 1)
+ mean_power = np.mean(power, axis=-1, keepdims=True)
+ if np.all(mean_power == 0):
+        return np.full(x.shape, fill_value=True, dtype=bool)
+ # detect_frames: (C, T)
+ detect_frames = power / mean_power > threshold
+ # detects: (C, T, F)
+ detects = np.broadcast_to(
+ detect_frames[..., None], detect_frames.shape + (frame_shift,)
+ )
+ # detects: (C, TF)
+ detects = detects.reshape(*detect_frames.shape[:-1], -1)
+ # detects: (C, TF)
+ return np.pad(
+ detects,
+ [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
+ mode="edge",
+ )
+
+
+class CommonPreprocessor(AbsPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train)
+ self.train = train
+ self.speech_name = speech_name
+ self.text_name = text_name
+ self.speech_volume_normalize = speech_volume_normalize
+ self.rir_apply_prob = rir_apply_prob
+ self.noise_apply_prob = noise_apply_prob
+ self.split_with_space = split_with_space
+ self.seg_dict = None
+ if seg_dict_file is not None:
+ self.seg_dict = {}
+ with open(seg_dict_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ self.seg_dict[key] = " ".join(value)
+
+ if token_type is not None:
+ if token_list is None:
+ raise ValueError("token_list is required if token_type is not None")
+ self.text_cleaner = TextCleaner(text_cleaner)
+
+ self.tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ self.token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+ )
+ else:
+ self.text_cleaner = None
+ self.tokenizer = None
+ self.token_id_converter = None
+
+ if train and rir_scp is not None:
+ self.rirs = []
+ with open(rir_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.rirs.append(sps[0])
+ else:
+ self.rirs.append(sps[1])
+ else:
+ self.rirs = None
+
+ if train and noise_scp is not None:
+ self.noises = []
+ with open(noise_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.noises.append(sps[0])
+ else:
+ self.noises.append(sps[1])
+ sps = noise_db_range.split("_")
+ if len(sps) == 1:
+                self.noise_db_low = self.noise_db_high = float(sps[0])
+ elif len(sps) == 2:
+ self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
+ else:
+ raise ValueError(
+ "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
+ )
+ else:
+ self.noises = None
+
+ def _speech_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, Union[str, np.ndarray]]:
+ assert check_argument_types()
+ if self.speech_name in data:
+ if self.train and (self.rirs is not None or self.noises is not None):
+ speech = data[self.speech_name]
+ nsamples = len(speech)
+
+ # speech: (Nmic, Time)
+ if speech.ndim == 1:
+ speech = speech[None, :]
+ else:
+ speech = speech.T
+                # Calc power on the non-silence region
+ power = (speech[detect_non_silence(speech)] ** 2).mean()
+
+ # 1. Convolve RIR
+ if self.rirs is not None and self.rir_apply_prob >= np.random.random():
+ rir_path = np.random.choice(self.rirs)
+ if rir_path is not None:
+ rir, _ = soundfile.read(
+ rir_path, dtype=np.float64, always_2d=True
+ )
+
+ # rir: (Nmic, Time)
+ rir = rir.T
+
+ # speech: (Nmic, Time)
+ # Note that this operation doesn't change the signal length
+ speech = scipy.signal.convolve(speech, rir, mode="full")[
+ :, : speech.shape[1]
+ ]
+ # Reverse mean power to the original power
+ power2 = (speech[detect_non_silence(speech)] ** 2).mean()
+ speech = np.sqrt(power / max(power2, 1e-10)) * speech
+
+ # 2. Add Noise
+ if (
+ self.noises is not None
+ and self.noise_apply_prob >= np.random.random()
+ ):
+ noise_path = np.random.choice(self.noises)
+ if noise_path is not None:
+ noise_db = np.random.uniform(
+ self.noise_db_low, self.noise_db_high
+ )
+ with soundfile.SoundFile(noise_path) as f:
+ if f.frames == nsamples:
+ noise = f.read(dtype=np.float64, always_2d=True)
+ elif f.frames < nsamples:
+ offset = np.random.randint(0, nsamples - f.frames)
+ # noise: (Time, Nmic)
+ noise = f.read(dtype=np.float64, always_2d=True)
+ # Repeat noise
+ noise = np.pad(
+ noise,
+ [(offset, nsamples - f.frames - offset), (0, 0)],
+ mode="wrap",
+ )
+ else:
+ offset = np.random.randint(0, f.frames - nsamples)
+ f.seek(offset)
+ # noise: (Time, Nmic)
+ noise = f.read(
+ nsamples, dtype=np.float64, always_2d=True
+ )
+ if len(noise) != nsamples:
+ raise RuntimeError(f"Something wrong: {noise_path}")
+ # noise: (Nmic, Time)
+ noise = noise.T
+
+ noise_power = (noise ** 2).mean()
+ scale = (
+ 10 ** (-noise_db / 20)
+ * np.sqrt(power)
+ / np.sqrt(max(noise_power, 1e-10))
+ )
+ speech = speech + scale * noise
+
+ speech = speech.T
+ ma = np.max(np.abs(speech))
+ if ma > 1.0:
+ speech /= ma
+ data[self.speech_name] = speech
+
+ if self.speech_volume_normalize is not None:
+ speech = data[self.speech_name]
+ ma = np.max(np.abs(speech))
+ data[self.speech_name] = speech * self.speech_volume_normalize / ma
+ assert check_return_type(data)
+ return data
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = forward_segment("".join(tokens), self.seg_dict)
+ tokens = seg_tokenize(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ assert check_argument_types()
+
+ data = self._speech_process(data)
+ data = self._text_process(data)
+ return data
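+
+# A minimal usage sketch (hypothetical token list and waveform `wav`): the
+# functor is applied per utterance by the dataset, mapping raw text to int64
+# token ids and optionally augmenting / normalizing the speech:
+#   >>> pre = CommonPreprocessor(train=True, token_type="char",
+#   ...                          token_list="data/tokens.txt")
+#   >>> out = pre("utt1", {"speech": wav, "text": "今天天气"})
+#   >>> out["text"].dtype
+#   dtype('int64')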
+
+
+## FIXME
+class LMPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train,
+ token_type,
+ token_list,
+ bpemodel,
+ text_cleaner,
+ g2p_type,
+ unk_symbol,
+ space_symbol,
+ non_linguistic_symbols,
+ delimiter,
+ rir_scp,
+ rir_apply_prob,
+ noise_scp,
+ noise_apply_prob,
+ noise_db_range,
+ speech_volume_normalize,
+ speech_name,
+ text_name,
+ split_with_space,
+ seg_dict_file,
+ )
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+
+class CommonPreprocessor_multi(AbsPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ ):
+ super().__init__(train)
+ self.train = train
+ self.speech_name = speech_name
+ self.text_name = text_name
+
+ if token_type is not None:
+ if token_list is None:
+ raise ValueError("token_list is required if token_type is not None")
+ self.text_cleaner = TextCleaner(text_cleaner)
+
+ self.tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ self.token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+ )
+ else:
+ self.text_cleaner = None
+ self.tokenizer = None
+ self.token_id_converter = None
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for text_n in self.text_name:
+ if text_n in data and self.tokenizer is not None:
+ text = data[text_n]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[text_n] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ assert check_argument_types()
+
+ if self.speech_name in data:
+ # Nothing now: candidates:
+ # - STFT
+ # - Fbank
+ # - CMVN
+ # - Data augmentation
+ pass
+
+ data = self._text_process(data)
+ return data
+
+
+class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: List[str] = [None],
+ token_list: List[Union[Path, str, Iterable[str]]] = [None],
+ bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ ):
+ # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+ super().__init__(
+ train=train,
+ token_type=token_type[0],
+ token_list=token_list[0],
+ bpemodel=bpemodel[0],
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name[0],
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ )
+
+ assert (
+ len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+ ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+ self.num_tokenizer = len(token_type)
+ self.tokenizer = []
+ self.token_id_converter = []
+
+ for i in range(self.num_tokenizer):
+ if token_type[i] is not None:
+ if token_list[i] is None:
+ raise ValueError("token_list is required if token_type is not None")
+
+ self.tokenizer.append(
+ build_tokenizer(
+ token_type=token_type[i],
+ bpemodel=bpemodel[i],
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ )
+ self.token_id_converter.append(
+ TokenIDConverter(
+ token_list=token_list[i],
+ unk_symbol=unk_symbol,
+ )
+ )
+ else:
+ self.tokenizer.append(None)
+ self.token_id_converter.append(None)
+
+ self.text_cleaner = TextCleaner(text_cleaner)
+ self.text_name = text_name # override the text_name from CommonPreprocessor
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for i in range(self.num_tokenizer):
+ text_name = self.text_name[i]
+ if text_name in data and self.tokenizer[i] is not None:
+ text = data[text_name]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer[i].text2tokens(text)
+ text_ints = self.token_id_converter[i].tokens2ids(tokens)
+ data[text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_text_name: str = "split_text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(
+ train=train,
+ # Force to use word.
+ token_type="word",
+ token_list=token_list,
+ bpemodel=bpemodel,
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name,
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ )
+ # The data field name for split text.
+ self.split_text_name = split_text_name
+
+ @classmethod
+ def split_words(cls, text: str):
+ words = []
+ segs = text.split()
+ for seg in segs:
+ # There is no space in seg.
+ current_word = ""
+ for c in seg:
+ if len(c.encode()) == 1:
+ # This is an ASCII char.
+ current_word += c
+ else:
+ # This is a Chinese char.
+ if len(current_word) > 0:
+ words.append(current_word)
+ current_word = ""
+ words.append(c)
+ if len(current_word) > 0:
+ words.append(current_word)
+ return words
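+
+    # A sketch of the code-mix splitting rule: ASCII runs stay whole tokens
+    # while every non-ASCII (CJK) character becomes its own token:
+    #   >>> CodeMixTokenizerCommonPreprocessor.split_words("hello你好ASR")
+    #   ['hello', '你', '好', 'ASR']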
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+ ) -> Dict[str, Union[list, np.ndarray]]:
+ assert check_argument_types()
+ # Split words.
+ if isinstance(data[self.text_name], str):
+ split_text = self.split_words(data[self.text_name])
+ else:
+ split_text = data[self.text_name]
+ data[self.text_name] = " ".join(split_text)
+ data = self._speech_process(data)
+ data = self._text_process(data)
+ data[self.split_text_name] = split_text
+ return data
+
+ def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+ result = data[self.split_text_name]
+ del data[self.split_text_name]
+ return result
+
+
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: List[str] = [None],
+ token_list: List[Union[Path, str, Iterable[str]]] = [None],
+ bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ vad_name: str = "vad_indexes",
+ ):
+ # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+ super().__init__(
+ train=train,
+ token_type=token_type[0],
+ token_list=token_list[0],
+ bpemodel=bpemodel[0],
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name[0],
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ )
+
+ assert (
+ len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+ ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+ self.num_tokenizer = len(token_type)
+ self.tokenizer = []
+ self.token_id_converter = []
+
+ for i in range(self.num_tokenizer):
+ if token_type[i] is not None:
+ if token_list[i] is None:
+ raise ValueError("token_list is required if token_type is not None")
+
+ self.tokenizer.append(
+ build_tokenizer(
+ token_type=token_type[i],
+ bpemodel=bpemodel[i],
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ )
+ self.token_id_converter.append(
+ TokenIDConverter(
+ token_list=token_list[i],
+ unk_symbol=unk_symbol,
+ )
+ )
+ else:
+ self.tokenizer.append(None)
+ self.token_id_converter.append(None)
+
+ self.text_cleaner = TextCleaner(text_cleaner)
+ self.text_name = text_name # override the text_name from CommonPreprocessor
+ self.vad_name = vad_name
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for i in range(self.num_tokenizer):
+ text_name = self.text_name[i]
+ if text_name in data and self.tokenizer[i] is not None:
+ text = data[text_name]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer[i].text2tokens(text)
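+                # A trailing "vad:<index>" token carries the VAD position; strip it from the tokens and store it under vad_name.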
+ if "vad:" in tokens[-1]:
+ vad = tokens[-1][4:]
+ tokens = tokens[:-1]
+ if len(vad) == 0:
+ vad = -1
+ else:
+ vad = int(vad)
+ data[self.vad_name] = np.array([vad], dtype=np.int64)
+ text_ints = self.token_id_converter[i].tokens2ids(tokens)
+ data[text_name] = np.array(text_ints, dtype=np.int64)
+        return data
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
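+    # Chop a word list into chunks of word_limit words; any remainder forms a final, shorter chunk.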
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
+
+
+def build_preprocess(args, train):
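+    # Build the task-specific preprocessor from the parsed args; returns None when preprocessing is disabled or for VAD.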
+ if not args.use_preprocessor:
+ return None
+ if args.task_name in ["asr", "data2vec", "diar", "sv"]:
+ retval = CommonPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob if hasattr(args, "rir_apply_prob") else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob if hasattr(args, "noise_apply_prob") else 1.0,
+ noise_db_range=args.noise_db_range if hasattr(args, "noise_db_range") else "13_15",
+            speech_volume_normalize=args.speech_volume_normalize if hasattr(args, "speech_volume_normalize") else None,
+ )
+ elif args.task_name == "punc":
+ token_types = [args.token_type, args.token_type]
+ token_lists = [args.token_list, args.punc_list]
+ bpemodels = [args.bpemodel, args.bpemodel]
+ text_names = ["text", "punc"]
+ retval = PuncTrainTokenizerCommonPreprocessor(
+ train=train,
+ token_type=token_types,
+ token_list=token_lists,
+ bpemodel=bpemodels,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ text_name=text_names,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ )
+ elif args.task_name == "lm":
+ retval = LMPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ text_name="text",
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ split_with_space=args.split_with_space,
+ seg_dict_file=args.seg_dict_file
+ )
+ elif args.task_name == "vad":
+ retval = None
+ else:
+        raise ValueError(f"Unsupported task: {args.task_name}")
+ return retval
diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
new file mode 100644
index 0000000..3ebcc5a
--- /dev/null
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -0,0 +1,189 @@
+import logging
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class RawSampler(AbsSampler):
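+    # A trivial sampler that wraps a precomputed list of batches; generate() ignores the seed and returns a copy.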
+ def __init__(self, batches):
+ self.batches = batches
+
+ def __len__(self):
+ return len(self.batches)
+
+ def __iter__(self):
+ return iter(self.batches)
+
+ def generate(self, seed):
+ return list(self.batches)
+
+
+class SequenceIterFactory(AbsIterFactory):
+ """Build iterator for each epoch, modified from ESPnet
+
+ """
+
+ def __init__(self, args, mode="train"):
+
+ # preprocess
+ preprocess_fn = build_preprocess(args, train=mode == "train")
+
+ # collate
+ if args.task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ # dataset
+ dest_sample_rate = args.frontend_conf["fs"] if (
+ args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+ if mode == "train":
+ data_path_and_name_and_type = args.train_data_path_and_name_and_type
+ shape_files = args.train_shape_file
+ elif mode == "valid":
+ data_path_and_name_and_type = args.valid_data_path_and_name_and_type
+ shape_files = args.valid_shape_file
+ else:
+ raise NotImplementedError(f"mode={mode}")
+ dataset = ESPnetDataset(
+ data_path_and_name_and_type,
+ preprocess=preprocess_fn,
+ dest_sample_rate=dest_sample_rate,
+            speed_perturb=args.speed_perturb if mode == "train" else None,
+ )
+
+ # sampler
+ dataset_conf = args.dataset_conf
+ batch_sampler = LengthBatchSampler(
+ batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
+ shape_files=shape_files,
+            sort_in_batch=dataset_conf["sort_in_batch"] if "sort_in_batch" in dataset_conf else "descending",
+            sort_batch=dataset_conf["sort_batch"] if "sort_batch" in dataset_conf else "ascending",
+ drop_last=False,
+ padding=True,
+ )
+
+ batches = list(batch_sampler)
+ bs_list = [len(batch) for batch in batches]
+ logging.info(f"[{mode}] dataset:\n{dataset}")
+ logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
+ logging.info(
+ f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
+ f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
+ )
+
+ if args.scheduler == "tri_stage" and mode == "train":
+ args.max_update = len(bs_list) * args.max_epoch
+ logging.info("Max update: {}".format(args.max_update))
+
+        if args.distributed and mode == "train":
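+            # Shard every batch across ranks with a strided slice so each worker gets an equal, disjoint subset.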
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+ for batch in batches:
+ if len(batch) < world_size:
+ raise RuntimeError(
+ f"The batch-size must be equal or more than world_size: "
+ f"{len(batch)} < {world_size}"
+ )
+ batches = [batch[rank::world_size] for batch in batches]
+
+ if not isinstance(batches, AbsSampler):
+ self.sampler = RawSampler(batches)
+ else:
+ self.sampler = batches
+
+ self.dataset = dataset
+ self.num_iters_per_epoch = None
+ self.shuffle = mode == "train"
+ self.seed = args.seed
+ self.num_workers = args.dataset_conf.get("num_workers", 8)
+ self.collate_fn = collate_fn
+ self.pin_memory = args.ngpu > 0
+
+ def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
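+        # Build this epoch's DataLoader; the batch order is derived deterministically from epoch + seed.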
+ if shuffle is None:
+ shuffle = self.shuffle
+
+ if self.num_iters_per_epoch is not None:
+ N = len(self.sampler)
+            # If the corpus size is larger than num_iters_per_epoch
+            if self.num_iters_per_epoch < N:
+ real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
+
+ if offset >= self.num_iters_per_epoch:
+ current_batches = self.sampler.generate(real_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(real_epoch + self.seed).shuffle(
+ current_batches
+ )
+ batches = current_batches[
+ offset - self.num_iters_per_epoch: offset
+ ]
+ else:
+ prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
+ current_batches = self.sampler.generate(real_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
+ prev_batches
+ )
+ np.random.RandomState(real_epoch + self.seed).shuffle(
+ current_batches
+ )
+ batches = (
+ prev_batches[offset - self.num_iters_per_epoch:]
+ + current_batches[:offset]
+ )
+
+            # If the corpus size is less than num_iters_per_epoch
+ else:
+ _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
+ _remain = self.num_iters_per_epoch
+ batches = []
+ current_batches = self.sampler.generate(_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
+ while _remain > 0:
+
+ _batches = current_batches[_cursor: _cursor + _remain]
+ batches += _batches
+ if _cursor + _remain >= N:
+ _epoch += 1
+ _cursor = 0
+ current_batches = self.sampler.generate(_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(_epoch + self.seed).shuffle(
+ current_batches
+ )
+ else:
+ _cursor = _cursor + _remain
+ _remain -= len(_batches)
+
+ assert len(batches) == self.num_iters_per_epoch
+
+ else:
+ batches = self.sampler.generate(epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(epoch + self.seed).shuffle(batches)
+
+ # For backward compatibility for pytorch DataLoader
+ if self.collate_fn is not None:
+ kwargs = dict(collate_fn=self.collate_fn)
+ else:
+ kwargs = {}
+
+ return DataLoader(
+ dataset=self.dataset,
+ batch_sampler=batches,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ **kwargs,
+ )
diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py
index 86be026..507226e 100644
--- a/funasr/export/test/test_onnx_punc_vadrealtime.py
+++ b/funasr/export/test/test_onnx_punc_vadrealtime.py
@@ -12,7 +12,7 @@
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
- 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ 'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
}
def _run(feed_dict):
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index d757f7f..c752fe6 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -2,10 +2,14 @@
from pathlib import Path
from typing import Union
+import random
import numpy as np
import soundfile
import librosa
from typeguard import check_argument_types
+
+import torch
+import torchaudio
from funasr.fileio.read_text import read_2column_text
@@ -32,6 +36,7 @@
always_2d: bool = False,
normalize: bool = False,
dest_sample_rate: int = 16000,
+ speed_perturb: Union[list, tuple] = None,
):
assert check_argument_types()
self.fname = fname
@@ -40,6 +45,7 @@
self.normalize = normalize
self.data = read_2column_text(fname)
self.dest_sample_rate = dest_sample_rate
+ self.speed_perturb = speed_perturb
def __getitem__(self, key):
wav = self.data[key]
@@ -53,8 +59,17 @@
wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
)
+ if self.speed_perturb is not None:
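+            # Apply sox-based speed perturbation: the 'speed' effect changes tempo and pitch,
+            # then 'rate' resamples back to the original sample rate.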
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ array, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(array).view(1, -1), rate,
+ [['speed', str(speed)], ['rate', str(rate)]])
+ array = array.view(-1).numpy()
+
if array.ndim==2:
array=array.transpose((1, 0))
+
return rate, array
def get_path(self, key):
diff --git a/funasr/layers/abs_normalize.py b/funasr/layers/abs_normalize.py
index f2be748..4e617d0 100644
--- a/funasr/layers/abs_normalize.py
+++ b/funasr/layers/abs_normalize.py
@@ -11,4 +11,4 @@
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/layers/global_mvn.py b/funasr/layers/global_mvn.py
index 5515cdd..8e43582 100644
--- a/funasr/layers/global_mvn.py
+++ b/funasr/layers/global_mvn.py
@@ -13,9 +13,7 @@
class GlobalMVN(AbsNormalize, InversibleInterface):
"""Apply global mean and variance normalization
-
TODO(kamo): Make this class portable somehow
-
Args:
stats_file: npy file
norm_means: Apply mean normalization
@@ -66,7 +64,6 @@
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
-
Args:
x: (B, L, ...)
ilens: (B,)
@@ -118,4 +115,4 @@
if norm_means:
x += self.mean
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
- return x, ilens
+ return x, ilens
\ No newline at end of file
diff --git a/funasr/layers/inversible_interface.py b/funasr/layers/inversible_interface.py
index a1a5939..657ec68 100644
--- a/funasr/layers/inversible_interface.py
+++ b/funasr/layers/inversible_interface.py
@@ -11,4 +11,4 @@
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/lm/abs_model.py b/funasr/lm/abs_model.py
deleted file mode 100644
index 1f3c8a7..0000000
--- a/funasr/lm/abs_model.py
+++ /dev/null
@@ -1,158 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from typing import Dict
-from typing import Optional
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from typeguard import check_argument_types
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
- """The abstract LM class
-
- To share the loss calculation way among different models,
- We uses delegate pattern here:
- The instance of this class should be passed to "LanguageModel"
-
- >>> from funasr.lm.abs_model import AbsLM
- >>> lm = AbsLM()
- >>> model = LanguageESPnetModel(lm=lm)
-
- This "model" is one of mediator objects for "Task" class.
-
- """
-
- @abstractmethod
- def forward(
- self, input: torch.Tensor, hidden: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
-
-class LanguageModel(AbsESPnetModel):
- def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
- assert check_argument_types()
- super().__init__()
- self.lm = lm
- self.sos = 1
- self.eos = 2
-
- # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
- self.ignore_id = ignore_id
-
- def nll(
- self,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- max_length: Optional[int] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute negative log likelihood(nll)
-
- Normally, this function is called in batchify_nll.
- Args:
- text: (Batch, Length)
- text_lengths: (Batch,)
- max_lengths: int
- """
- batch_size = text.size(0)
- # For data parallel
- if max_length is None:
- text = text[:, : text_lengths.max()]
- else:
- text = text[:, :max_length]
-
- # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
- # text: (Batch, Length) -> x, y: (Batch, Length + 1)
- x = F.pad(text, [1, 0], "constant", self.sos)
- t = F.pad(text, [0, 1], "constant", self.ignore_id)
- for i, l in enumerate(text_lengths):
- t[i, l] = self.eos
- x_lengths = text_lengths + 1
-
- # 2. Forward Language model
- # x: (Batch, Length) -> y: (Batch, Length, NVocab)
- y, _ = self.lm(x, None)
-
- # 3. Calc negative log likelihood
- # nll: (BxL,)
- nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
- # nll: (BxL,) -> (BxL,)
- if max_length is None:
- nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
- else:
- nll.masked_fill_(
- make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
- 0.0,
- )
- # nll: (BxL,) -> (B, L)
- nll = nll.view(batch_size, -1)
- return nll, x_lengths
-
- def batchify_nll(
- self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute negative log likelihood(nll) from transformer language model
-
- To avoid OOM, this fuction seperate the input into batches.
- Then call nll for each batch and combine and return results.
- Args:
- text: (Batch, Length)
- text_lengths: (Batch,)
- batch_size: int, samples each batch contain when computing nll,
- you may change this to avoid OOM or increase
-
- """
- total_num = text.size(0)
- if total_num <= batch_size:
- nll, x_lengths = self.nll(text, text_lengths)
- else:
- nlls = []
- x_lengths = []
- max_length = text_lengths.max()
-
- start_idx = 0
- while True:
- end_idx = min(start_idx + batch_size, total_num)
- batch_text = text[start_idx:end_idx, :]
- batch_text_lengths = text_lengths[start_idx:end_idx]
- # batch_nll: [B * T]
- batch_nll, batch_x_lengths = self.nll(
- batch_text, batch_text_lengths, max_length=max_length
- )
- nlls.append(batch_nll)
- x_lengths.append(batch_x_lengths)
- start_idx = end_idx
- if start_idx == total_num:
- break
- nll = torch.cat(nlls)
- x_lengths = torch.cat(x_lengths)
- assert nll.size(0) == total_num
- assert x_lengths.size(0) == total_num
- return nll, x_lengths
-
- def forward(
- self, text: torch.Tensor, text_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- nll, y_lengths = self.nll(text, text_lengths)
- ntokens = y_lengths.sum()
- loss = nll.sum() / ntokens
- stats = dict(loss=loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
- return loss, stats, weight
-
- def collect_feats(
- self, text: torch.Tensor, text_lengths: torch.Tensor
- ) -> Dict[str, torch.Tensor]:
- return {}
diff --git a/funasr/main_funcs/calculate_all_attentions.py b/funasr/main_funcs/calculate_all_attentions.py
index 8f238c6..c3bf015 100644
--- a/funasr/main_funcs/calculate_all_attentions.py
+++ b/funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
from funasr.modules.attention import MultiHeadedAttention
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def calculate_all_attentions(
- model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
+ model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
"""Derive the outputs from the all attention layers
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index bacda8f..584b85a 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def collect_stats(
- model: AbsESPnetModel,
+ model: FunASRModel,
train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
output_dir: Path,
diff --git a/funasr/models/base_model.py b/funasr/models/base_model.py
new file mode 100644
index 0000000..80b3bbd
--- /dev/null
+++ b/funasr/models/base_model.py
@@ -0,0 +1,17 @@
+import torch
+
+
+class FunASRModel(torch.nn.Module):
+ """The common model class
+
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.num_updates = 0
+
+ def set_num_updates(self, num_updates):
+ self.num_updates = num_updates
+
+ def get_num_updates(self):
+ return self.num_updates
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index fcd6bd2..e5bd640 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -18,7 +18,7 @@
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -29,7 +29,7 @@
yield
-class Data2VecPretrainModel(AbsESPnetModel):
+class Data2VecPretrainModel(FunASRModel):
"""Data2Vec Pretrain model"""
def __init__(
@@ -57,7 +57,6 @@
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -106,7 +105,6 @@
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index f64ea3d..e6e6a52 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -28,7 +28,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -39,7 +39,7 @@
yield
-class ESPnetASRModel(AbsESPnetModel):
+class ASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -49,9 +49,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -64,6 +62,8 @@
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -133,7 +133,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -249,7 +248,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -331,9 +329,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -370,7 +366,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this function separates the input into batches.
It then calls nll on each batch and concatenates the results.
Args:
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index f22f12a..fbf0d11 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -23,7 +23,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -35,7 +35,8 @@
import pdb
import random
import math
-class MFCCA(AbsESPnetModel):
+
+class MFCCA(FunASRModel):
"""
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
@@ -43,26 +44,26 @@
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- decoder: AbsDecoder,
- ctc: CTC,
- rnnt_decoder: None,
- ctc_weight: float = 0.5,
- ignore_id: int = -1,
- lsm_weight: float = 0.0,
- mask_ratio: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ rnnt_decoder: None,
+ ctc_weight: float = 0.5,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ mask_ratio: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ preencoder: Optional[AbsPreEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -76,10 +77,9 @@
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
-
+
self.mask_ratio = mask_ratio
-
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
@@ -113,14 +113,13 @@
self.error_calculator = None
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -130,22 +129,22 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- #pdb.set_trace()
- if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+ # pdb.set_trace()
+ if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
rate_num = random.random()
- #rate_num = 0.1
- if(rate_num<=self.mask_ratio):
- retain_channel = math.ceil(random.random() *8)
- if(retain_channel>1):
- speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+ # rate_num = 0.1
+ if (rate_num <= self.mask_ratio):
+ retain_channel = math.ceil(random.random() * 8)
+ if (retain_channel > 1):
+ speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
else:
- speech = speech[:,:,torch.randperm(8)[0]]
- #pdb.set_trace()
+ speech = speech[:, :, torch.randperm(8)[0]]
+ # pdb.set_trace()
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
@@ -195,20 +194,19 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -227,14 +225,14 @@
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
- #pdb.set_trace()
+ # pdb.set_trace()
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
@@ -248,7 +246,7 @@
return encoder_out, encoder_out_lens
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
@@ -266,11 +264,11 @@
return feats, feats_lengths, channel_size
def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
@@ -298,14 +296,14 @@
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
@@ -317,10 +315,10 @@
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index d02783f..9241271 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -29,9 +29,8 @@
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -42,7 +41,7 @@
yield
-class Paraformer(AbsESPnetModel):
+class Paraformer(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -56,9 +55,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -79,6 +76,9 @@
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
+ use_1st_decoder_loss: bool = False,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -145,6 +145,8 @@
if self.share_embedding:
self.decoder.embed = None
+ self.use_1st_decoder_loss = use_1st_decoder_loss
+
def forward(
self,
speech: torch.Tensor,
@@ -153,7 +155,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -181,7 +182,7 @@
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
@@ -222,7 +223,7 @@
# 2b. Attention decoder branch
if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+ loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
@@ -234,8 +235,12 @@
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+ if self.use_1st_decoder_loss and pre_loss_att is not None:
+ loss = loss + pre_loss_att
+
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
stats["acc"] = acc_att
stats["cer"] = cer_att
stats["wer"] = wer_att
@@ -270,7 +275,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -368,9 +372,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -407,7 +409,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this function separates the input into batches.
It then calls nll on each batch and concatenates the results.
Args:
@@ -462,11 +463,16 @@
# 0. sampler
decoder_out_1st = None
+ pre_loss_att = None
if self.sampling_ratio > 0.0:
if self.step_cur < 2:
logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
- sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
- pre_acoustic_embeds)
+ if self.use_1st_decoder_loss:
+ sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
+ else:
+ sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+ pre_acoustic_embeds)
else:
if self.step_cur < 2:
logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
@@ -496,7 +502,7 @@
ys_hat = decoder_out_1st.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
- return loss_att, acc_att, cer_att, wer_att, loss_pre
+ return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
@@ -528,6 +534,37 @@
sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
input_mask_expand_dim, 0)
return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+ def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
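+        # Glancing-sampler variant: unlike sampler(), the first decoder pass stays in the
+        # autograd graph and its CE loss is returned as pre_loss_att.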
+ tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+ ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+ if self.share_embedding:
+ ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+ else:
+ ys_pad_embed = self.decoder.embed(ys_pad_masked)
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+ )
+ pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
+ decoder_out, _ = decoder_outs[0], decoder_outs[1]
+ pred_tokens = decoder_out.argmax(-1)
+ nonpad_positions = ys_pad.ne(self.ignore_id)
+ seq_lens = (nonpad_positions).sum(1)
+ same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+ input_mask = torch.ones_like(nonpad_positions)
+ bsz, seq_len = ys_pad.size()
+ for li in range(bsz):
+ target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+ if target_num > 0:
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(ys_pad.device), value=0)
+ input_mask = input_mask.eq(1)
+ input_mask = input_mask.masked_fill(~nonpad_positions, False)
+ input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+ sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+ input_mask_expand_dim, 0)
+
+ return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
def _calc_ctc_loss(
self,
@@ -664,7 +701,10 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -738,9 +778,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -763,6 +801,8 @@
embeds_id: int = 2,
embeds_loss_weight: float = 0.0,
embed_dims: int = 768,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -894,7 +934,6 @@
embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -913,9 +952,9 @@
self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max(), :]
+ speech = speech[:, :speech_lengths.max()]
if embed is not None:
- embed = embed[:, :embed_lengths.max(), :]
+ embed = embed[:, :embed_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1003,74 +1042,73 @@
class BiCifParaformer(Paraformer):
-
"""
Paraformer model with an extra cif predictor
to conduct accurate timestamp prediction
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
- decoder: AbsDecoder,
- ctc: CTC,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
- predictor = None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- decoder=decoder,
- ctc=ctc,
- ctc_weight=ctc_weight,
- interctc_weight=interctc_weight,
- ignore_id=ignore_id,
- blank_id=blank_id,
- sos=sos,
- eos=eos,
- lsm_weight=lsm_weight,
- length_normalized_loss=length_normalized_loss,
- report_cer=report_cer,
- report_wer=report_wer,
- sym_space=sym_space,
- sym_blank=sym_blank,
- extract_feats_in_collect_stats=extract_feats_in_collect_stats,
- predictor=predictor,
- predictor_weight=predictor_weight,
- predictor_bias=predictor_bias,
- sampling_ratio=sampling_ratio,
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
@@ -1145,21 +1183,23 @@
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, loss_pre
-
+
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
+
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
@@ -1170,7 +1210,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1253,7 +1292,8 @@
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1282,9 +1322,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -1314,6 +1352,8 @@
bias_encoder_type: str = 'lstm',
label_bracket: bool = False,
use_decoder_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1377,7 +1417,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1761,4 +1800,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index a5aaa6c..3120087 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -17,7 +17,7 @@
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
@@ -28,7 +28,7 @@
yield
-class TransducerModel(AbsESPnetModel):
+class TransducerModel(FunASRModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
@@ -483,7 +483,7 @@
return loss_lm
-class UnifiedTransducerModel(AbsESPnetModel):
+class UnifiedTransducerModel(FunASRModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
vocab_size: Size of complete vocabulary (w/ EOS and blank included).
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 097b23a..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -16,7 +16,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -34,7 +34,7 @@
return att
-class DiarEENDOLAModel(AbsESPnetModel):
+class DiarEENDOLAModel(FunASRModel):
"""EEND-OLA diarization model"""
def __init__(
@@ -91,7 +91,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 3f7011d..9c3fb92 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -22,7 +22,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -35,7 +35,7 @@
yield
-class DiarSondModel(AbsESPnetModel):
+class DiarSondModel(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
@@ -115,7 +115,6 @@
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
-
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk interator,
@@ -391,7 +390,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
@@ -491,4 +489,4 @@
speaker_miss,
speaker_falarm,
speaker_error,
- )
+ )
\ No newline at end of file
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index f694cc2..8304607 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -29,7 +29,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -40,7 +40,7 @@
yield
-class ESPnetASRModel(AbsESPnetModel):
+class ESPnetASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index 5b21277..bd5178e 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -29,7 +29,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -40,7 +40,7 @@
yield
-class ESPnetSVModel(AbsESPnetModel):
+class ESPnetSVModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -80,7 +80,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -221,7 +220,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -271,4 +269,4 @@
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
+ return feats, feats_lengths
\ No newline at end of file
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index d1367ab..33948f9 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -17,9 +17,8 @@
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -30,7 +29,7 @@
yield
-class TimestampPredictor(AbsESPnetModel):
+class TimestampPredictor(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -56,7 +55,7 @@
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
-
+
def forward(
self,
speech: torch.Tensor,
@@ -65,7 +64,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -113,7 +111,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -128,7 +125,7 @@
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
-
+
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -151,8 +148,8 @@
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def collect_feats(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ca76244..d08ea37 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -25,7 +25,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
from funasr.models.predictor.cif import mae_loss
@@ -38,7 +38,7 @@
yield
-class UniASR(AbsESPnetModel):
+class UniASR(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -179,7 +179,6 @@
decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -469,7 +468,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -533,7 +531,6 @@
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -627,9 +624,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -666,7 +661,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this function separates the input into batches.
It then calls nll on each batch and concatenates the results.
Args:
@@ -1072,4 +1066,3 @@
ys_hat = self.ctc2.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
-
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index d72c635..82d8422 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -469,7 +469,7 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
+
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -499,11 +499,11 @@
return segments, in_cache
def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
- is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+ is_final: bool = False, max_end_sil: int = 800
+ ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
-
+
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
index 1fb7c97..034bc1f 100644
--- a/funasr/models/encoder/abs_encoder.py
+++ b/funasr/models/encoder/abs_encoder.py
@@ -18,4 +18,4 @@
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 434f2a4..5f20dee 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -15,13 +15,13 @@
from typeguard import check_argument_types
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
@@ -1078,7 +1078,7 @@
limit_size,
)
- mask = make_source_mask(x_len)
+ mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796c..64c2144 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -574,4 +574,4 @@
)
def output_size(self) -> int:
- return self.encoder_embed_dim
+ return self.encoder_embed_dim
\ No newline at end of file
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..95ccf07 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -38,13 +38,12 @@
import pdb
import math
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
-
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
-
"""
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -83,13 +82,10 @@
def forward(self, x):
"""Compute convolution module.
-
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
-
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
-
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
@@ -107,10 +103,8 @@
return x.transpose(1, 2)
-
class MFCCAEncoder(AbsEncoder):
"""Conformer encoder module.
-
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
@@ -140,33 +134,32 @@
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
-
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
@@ -199,7 +192,7 @@
)
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-
+
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
@@ -283,7 +276,7 @@
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
- attention_heads,
+ attention_heads,
output_size,
attention_dropout_rate,
)
@@ -326,42 +319,39 @@
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
- self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+ self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+ self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+ self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+ self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
def output_size(self) -> int:
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- channel_size: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ channel_size: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
-
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
-
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
-
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -380,48 +370,46 @@
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
- xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
- #pdb.set_trace()
- if(channel_size<8):
- repeat_num = math.ceil(8/channel_size)
- xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
+ xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
+ # pdb.set_trace()
+ if (channel_size < 8):
+ repeat_num = math.ceil(8 / channel_size)
+ xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
xs_pad = self.conv1(xs_pad)
xs_pad = self.conv2(xs_pad)
xs_pad = self.conv3(xs_pad)
xs_pad = self.conv4(xs_pad)
- xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
+ xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
mask_tmp = masks.size(1)
- masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
+ masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
+
def forward_hidden(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
-
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
-
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
-
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -447,4 +435,4 @@
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
- return xs_pad, olens, None
+ return xs_pad, olens, None
\ No newline at end of file
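Aside from reindentation, the forward pass above reshapes the batch into (batch, channel, time, dim) and, when fewer than 8 channels are available, tiles the channel axis up to the 8 channels the conv1–conv4 stack expects. A standalone illustration of that tiling step:

```python
import math
import torch

xs = torch.randn(1, 3, 10, 80)   # (batch, channels, time, feat_dim)
channel_size = xs.size(1)
if channel_size < 8:
    repeat_num = math.ceil(8 / channel_size)           # 3 -> repeat 3x = 9
    xs = xs.repeat(1, repeat_num, 1, 1)[:, :8, :, :]   # then truncate to 8
print(xs.shape)  # torch.Size([1, 8, 10, 80])
```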
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 93695c8..8445feb 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -850,4 +850,4 @@
else:
logging.warning("{} is missed from tf checkpoint".format(name))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b053..59730da 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -1,3 +1,4 @@
+
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -14,7 +15,6 @@
class RNNEncoder(AbsEncoder):
"""RNNEncoder class.
-
Args:
input_size: The number of expected features in the input
output_size: The number of output features
@@ -23,7 +23,6 @@
use_projection: Use projection layer or not
num_layers: Number of recurrent layers
dropout: dropout probability
-
"""
def __init__(
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a68011..da67586 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -27,9 +27,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -354,18 +355,9 @@
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
- # process last chunk
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- if cache["is_final"]:
- cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
- if not cache["last_chunk"]:
- padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
- overlap_feats = overlap_feats.transpose(1, 2)
- overlap_feats = F.pad(overlap_feats, (0, padding_length))
- overlap_feats = overlap_feats.transpose(1, 2)
- else:
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
return overlap_feats
def forward_chunk(self,
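The rewritten `_add_overlap_chunk` drops the `is_final`/`last_chunk` branches: the cache now always retains the last `chunk_size[0] + chunk_size[2]` frames, regardless of finality. A minimal sketch of the simplified behavior, assuming `chunk_size` holds `[left, center, right]` context sizes as in the surrounding code:

```python
import torch

def add_overlap_chunk(feats: torch.Tensor, cache: dict) -> torch.Tensor:
    # Prepend the cached frames, then retain (left + right) frames
    # of the concatenated chunk for the next call.
    if not cache:
        return feats
    overlap = torch.cat((cache["feats"], feats), dim=1)
    left, _, right = cache["chunk_size"]
    cache["feats"] = overlap[:, -(left + right):, :]
    return overlap

cache = {"feats": torch.zeros(1, 4, 8), "chunk_size": [4, 8, 2]}
out = add_overlap_chunk(torch.randn(1, 8, 8), cache)
print(out.shape, cache["feats"].shape)  # torch.Size([1, 12, 8]) torch.Size([1, 6, 8])
```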
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
index 538236f..6049a01 100644
--- a/funasr/models/frontend/abs_frontend.py
+++ b/funasr/models/frontend/abs_frontend.py
@@ -14,4 +14,4 @@
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index c4dd7c5..19994f0 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -18,7 +18,6 @@
class DefaultFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -142,7 +141,6 @@
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -260,4 +258,4 @@
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
+ return input_stft, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 8b5e56e..857486d 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -143,4 +143,4 @@
else:
raise NotImplementedError
- return input_feats, feats_lens
+ return input_feats, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index f2b6107..b03d2c9 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -100,7 +100,6 @@
def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
-
Input - sequence of representations
shape: (batch_size, seq_len, feature_dim)
Output - sequence of tiled representations
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 1dbf490..35fab57 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -500,4 +500,4 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
- return feats_pad, feats_lens
+ return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index b91ac63..5372de3 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,15 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-from typing import Tuple
-
import numpy as np
import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
-# import kaldifeat
+
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -33,9 +27,9 @@
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
- return cmvn
-
+ cmvn = torch.as_tensor(cmvn)
+ return cmvn
+
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@@ -73,108 +67,3 @@
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
-
-# class WavFrontend_kaldifeat(AbsFrontend):
-# """Conventional frontend structure for ASR.
-# """
-#
-# def __init__(
-# self,
-# cmvn_file: str = None,
-# fs: int = 16000,
-# window: str = 'hamming',
-# n_mels: int = 80,
-# frame_length: int = 25,
-# frame_shift: int = 10,
-# lfr_m: int = 1,
-# lfr_n: int = 1,
-# dither: float = 1.0,
-# snip_edges: bool = True,
-# upsacle_samples: bool = True,
-# device: str = 'cpu',
-# **kwargs,
-# ):
-# super().__init__()
-#
-# opts = kaldifeat.FbankOptions()
-# opts.device = device
-# opts.frame_opts.samp_freq = fs
-# opts.frame_opts.dither = dither
-# opts.frame_opts.window_type = window
-# opts.frame_opts.frame_shift_ms = float(frame_shift)
-# opts.frame_opts.frame_length_ms = float(frame_length)
-# opts.mel_opts.num_bins = n_mels
-# opts.energy_floor = 0
-# opts.frame_opts.snip_edges = snip_edges
-# opts.mel_opts.debug_mel = False
-# self.opts = opts
-# self.fbank_fn = None
-# self.fbank_beg_idx = 0
-# self.reset_fbank_status()
-#
-# self.lfr_m = lfr_m
-# self.lfr_n = lfr_n
-# self.cmvn_file = cmvn_file
-# self.upsacle_samples = upsacle_samples
-#
-# def output_size(self) -> int:
-# return self.n_mels * self.lfr_m
-#
-# def forward_fbank(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# waveform_length = input_lengths[i]
-# waveform = input[i][:waveform_length]
-# waveform = waveform * (1 << 15)
-#
-# self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-# frames = self.fbank_fn.num_frames_ready
-# frames_cur = frames - self.fbank_beg_idx
-# mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
-# device=self.opts.device)
-# for i in range(self.fbank_beg_idx, frames):
-# mat[i, :] = self.fbank_fn.get_frame(i)
-# self.fbank_beg_idx += frames_cur
-#
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
-#
-# def reset_fbank_status(self):
-# self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
-# self.fbank_beg_idx = 0
-#
-# def forward_lfr_cmvn(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# mat = input[i, :input_lengths[i], :]
-# if self.lfr_m != 1 or self.lfr_n != 1:
-# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-# if self.cmvn_file is not None:
-# mat = apply_cmvn(mat, self.cmvn_file)
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
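The surviving `apply_lfr` above stacks low frame rate (LFR) features: every `lfr_n` frames it concatenates `lfr_m` consecutive frames into one wider frame. A simplified sketch of the idea (the funasr implementation differs in how it pads the edges):

```python
import torch

def apply_lfr_sketch(feats: torch.Tensor, lfr_m: int, lfr_n: int) -> torch.Tensor:
    # Stack lfr_m consecutive frames every lfr_n frames; repeat the last
    # frame at the tail so each output frame sees exactly lfr_m inputs.
    T, _ = feats.shape
    out = []
    for i in range(0, T, lfr_n):
        chunk = feats[i:i + lfr_m]
        if chunk.shape[0] < lfr_m:
            pad = chunk[-1:].repeat(lfr_m - chunk.shape[0], 1)
            chunk = torch.cat((chunk, pad), dim=0)
        out.append(chunk.reshape(-1))
    return torch.vstack(out).type(torch.float32)

print(apply_lfr_sketch(torch.randn(10, 80), lfr_m=7, lfr_n=6).shape)  # (2, 560)
```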
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index 7c4c568..a526758 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -12,12 +12,10 @@
class SlidingWindow(AbsFrontend):
"""Sliding Window.
-
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
-
Known issues:
Output length is calculated incorrectly if the audio is shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
@@ -33,7 +31,6 @@
fs=None,
):
"""Initialize.
-
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
@@ -53,11 +50,9 @@
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
-
Args:
input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
input_lengths: Input lengths within batch.
-
Returns:
Tensor: Output with dimensions (B, T, C, D), with D=win_length.
Tensor: Output lengths within batch.
@@ -78,4 +73,4 @@
def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the window length."""
- return self.win_length
+ return self.win_length
\ No newline at end of file
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c59e245..3c363db 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -221,13 +221,14 @@
if cache is not None and "chunk_size" in cache:
alphas[:, :cache["chunk_size"][0]] = 0.0
- alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
+ if "is_final" in cache and not cache["is_final"]:
+ alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
- if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
+ if cache is not None and "is_final" in cache and cache["is_final"]:
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
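Context for this hunk: CIF accumulates per-frame weights (alphas) and fires a token each time the running sum crosses a threshold. The change only zeroes the right-context alphas when the chunk is not final, and keys the tail flush on `is_final` instead of `last_chunk`, so the residual weight below the threshold is emitted exactly once at stream end. A toy sketch of the firing rule (not the funasr implementation):

```python
import torch

def cif_fire(alphas: torch.Tensor, threshold: float = 1.0):
    # Accumulate weights; record a fire whenever the sum crosses the
    # threshold. The leftover accumulator is the residue that a final
    # chunk must flush via the tail_threshold trick.
    fires, acc = [], 0.0
    for t, a in enumerate(alphas.tolist()):
        acc += a
        if acc >= threshold:
            fires.append(t)
            acc -= threshold
    return fires, acc

print(cif_fire(torch.tensor([0.3, 0.4, 0.5, 0.2, 0.7])))  # ([2, 4], ~0.1)
```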
diff --git a/funasr/lm/seq_rnn_lm.py b/funasr/models/seq_rnn_lm.py
similarity index 98%
rename from funasr/lm/seq_rnn_lm.py
rename to funasr/models/seq_rnn_lm.py
index 09d1e4a..f7ddcae 100644
--- a/funasr/lm/seq_rnn_lm.py
+++ b/funasr/models/seq_rnn_lm.py
@@ -5,8 +5,7 @@
import torch
import torch.nn as nn
from typeguard import check_argument_types
-
-from funasr.lm.abs_model import AbsLM
+from funasr.train.abs_model import AbsLM
class SequentialRNNLM(AbsLM):
diff --git a/funasr/models/specaug/abs_specaug.py b/funasr/models/specaug/abs_specaug.py
index 3cbac41..da6637e 100644
--- a/funasr/models/specaug/abs_specaug.py
+++ b/funasr/models/specaug/abs_specaug.py
@@ -6,9 +6,7 @@
class AbsSpecAug(torch.nn.Module):
"""Abstract class for the augmentation of spectrogram
-
The process-flow:
-
Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
"""
diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
index e893c65..19e5c7c 100644
--- a/funasr/models/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -6,13 +6,10 @@
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
-#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-#from funasr.modules.mask import subsequent_n_mask
-from funasr.train.abs_model import AbsPunctuation
-class TargetDelayTransformer(AbsPunctuation):
+class TargetDelayTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/lm/transformer_lm.py b/funasr/models/transformer_lm.py
similarity index 98%
rename from funasr/lm/transformer_lm.py
rename to funasr/models/transformer_lm.py
index 52af45b..1cd76dc 100644
--- a/funasr/lm/transformer_lm.py
+++ b/funasr/models/transformer_lm.py
@@ -8,7 +8,7 @@
from funasr.modules.embedding import PositionalEncoding
from funasr.models.encoder.transformer_encoder import TransformerEncoder_s0 as Encoder
from funasr.modules.mask import subsequent_mask
-from funasr.lm.abs_model import AbsLM
+from funasr.train.abs_model import AbsLM
class TransformerLM(AbsLM):
diff --git a/funasr/models/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
index fe298ce..e2d13f9 100644
--- a/funasr/models/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -7,10 +7,9 @@
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.train.abs_model import AbsPunctuation
-class VadRealtimeTransformer(AbsPunctuation):
+class VadRealtimeTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 8890714..777de4f 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -186,11 +186,12 @@
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
+ vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
- "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
- "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ "vad_mask": vad_mask,
+ "sub_masks": vad_mask
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
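Hoisting `vad_mask` out of the dict avoids building the O(T²) mask twice per mini-sentence; note that the change also feeds the VAD mask for the `sub_masks` input in place of the previous lower-triangular mask, presumably matching what the exported ONNX punctuation graph expects. For reference, the mask layout involved:

```python
import numpy as np

text_length = 4
# The previous sub_masks input: a plain causal (lower-triangular) mask,
# broadcast to the (1, 1, T, T) layout used by the ONNX inputs.
sub_masks = np.tril(np.ones((text_length, text_length), dtype=np.float32))
print(sub_masks[None, None, :, :].shape)  # (1, 1, 4, 4)
```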
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 55a5d79..fd4e190 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@
import torch.nn
import torch.optim
import yaml
+from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -44,19 +45,18 @@
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.distributed_utils import DistributedOption
from funasr.train.trainer import Trainer
@@ -230,8 +230,8 @@
>>> cls.check_task_requirements()
If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "required_data_names" should be as
@@ -251,8 +251,8 @@
>>> cls.check_task_requirements()
If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "optional_data_names" should be as
@@ -263,8 +263,9 @@
@classmethod
@abstractmethod
- def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+ def build_model(cls, args: argparse.Namespace) -> FunASRModel:
raise NotImplementedError
+
@classmethod
def get_parser(cls) -> config_argparse.ArgumentParser:
@@ -1172,7 +1173,8 @@
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
- if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+ if (
+ args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
if not args.simple_ddp or distributed_option.dist_rank == 0:
filter_wav_text(args.data_dir, args.train_set)
filter_wav_text(args.data_dir, args.dev_set)
@@ -1181,8 +1183,10 @@
if args.train_shape_file is None and args.dataset_type == "small":
if not args.simple_ddp or distributed_option.dist_rank == 0:
- calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
- calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+ calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+ args.speech_length_max)
+ calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+ args.speech_length_max)
if args.simple_ddp:
dist.barrier()
args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1244,9 +1248,9 @@
# 2. Build model
model = cls.build_model(args=args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model = model.to(
dtype=getattr(torch, args.train_dtype),
@@ -1374,15 +1378,21 @@
if args.dataset_type == "large":
from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+ frontend_conf=args.frontend_conf if hasattr(args,
+ "frontend_conf") else None,
+ seg_dict_file=args.seg_dict_file if hasattr(args,
+ "seg_dict_file") else None,
+ punc_dict_file=args.punc_list if hasattr(args,
+ "punc_list") else None,
bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
mode="train")
- valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+ valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+ frontend_conf=args.frontend_conf if hasattr(args,
+ "frontend_conf") else None,
+ seg_dict_file=args.seg_dict_file if hasattr(args,
+ "seg_dict_file") else None,
+ punc_dict_file=args.punc_list if hasattr(args,
+ "punc_list") else None,
bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
mode="eval")
elif args.dataset_type == "small":
@@ -1929,7 +1939,7 @@
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
- ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+ ) -> Tuple[FunASRModel, argparse.Namespace]:
"""Build model from the files.
This method is used for inference or fine-tuning.
@@ -1956,9 +1966,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
if model_file is not None:
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 43ea5ab..8e4f9cc 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,9 +38,9 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
@@ -76,7 +76,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -122,7 +122,7 @@
model_choices = ClassChoices(
"model",
classes=dict(
- asr=ESPnetASRModel,
+ asr=ASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_online=ParaformerOnline,
@@ -132,8 +132,10 @@
neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
+ rnnt=TransducerModel,
+ rnnt_unified=UnifiedTransducerModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="asr",
)
preencoder_choices = ClassChoices(
@@ -224,6 +226,15 @@
default="rnnt",
)
+joint_network_choices = ClassChoices(
+ name="joint_network",
+ classes=dict(
+ joint_network=JointNetwork,
+ ),
+ default="joint_network",
+ optional=True,
+)
+
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
@@ -280,6 +291,18 @@
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
+ # --predictor and --predictor_conf
+ predictor_choices,
+ # --encoder2 and --encoder2_conf
+ encoder_choices2,
+ # --decoder2 and --decoder2_conf
+ decoder_choices2,
+ # --predictor2 and --predictor2_conf
+ predictor_choices2,
+ # --stride_conv and --stride_conv_conf
+ stride_conv_choices,
+ # --rnnt_decoder and --rnnt_decoder_conf
+ rnnt_decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
@@ -341,7 +364,7 @@
help="The keyword arguments for CTC class.",
)
group.add_argument(
- "--joint_net_conf",
+ "--joint_network_conf",
action=NestedDictAction,
default=None,
help="The keyword arguments for joint network class.",
@@ -457,7 +480,7 @@
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
- non_linguistic_symbols=args.non_linguistic_symbols,
+ non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
@@ -827,9 +850,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -899,27 +922,27 @@
# If you need more than one optimizers, change this value
num_optimizers: int = 1
- # Add variable objects configurations
- class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --model and --model_conf
- model_choices,
- # --preencoder and --preencoder_conf
- preencoder_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --postencoder and --postencoder_conf
- postencoder_choices,
- # --decoder and --decoder_conf
- decoder_choices,
- # --predictor and --predictor_conf
- predictor_choices,
- ]
+ # # Add variable objects configurations
+ # class_choices_list = [
+ # # --frontend and --frontend_conf
+ # frontend_choices,
+ # # --specaug and --specaug_conf
+ # specaug_choices,
+ # # --normalize and --normalize_conf
+ # normalize_choices,
+ # # --model and --model_conf
+ # model_choices,
+ # # --preencoder and --preencoder_conf
+ # preencoder_choices,
+ # # --encoder and --encoder_conf
+ # encoder_choices,
+ # # --postencoder and --postencoder_conf
+ # postencoder_choices,
+ # # --decoder and --decoder_conf
+ # decoder_choices,
+ # # --predictor and --predictor_conf
+ # predictor_choices,
+ # ]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@@ -1074,9 +1097,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -1350,7 +1373,7 @@
return retval
-class ASRTransducerTask(AbsTask):
+class ASRTransducerTask(ASRTask):
"""ASR Transducer Task definition."""
num_optimizers: int = 1
@@ -1361,243 +1384,10 @@
normalize_choices,
encoder_choices,
rnnt_decoder_choices,
+ joint_network_choices,
]
trainer = Trainer
-
- @classmethod
- def add_task_arguments(cls, parser: argparse.ArgumentParser):
- """Add Transducer task arguments.
- Args:
- cls: ASRTransducerTask object.
- parser: Transducer arguments parser.
- """
- group = parser.add_argument_group(description="Task related.")
-
- # required = parser.get_default("required")
- # required += ["token_list"]
-
- group.add_argument(
- "--token_list",
- type=str_or_none,
- default=None,
- help="Integer-string mapper for tokens.",
- )
- group.add_argument(
- "--split_with_space",
- type=str2bool,
- default=True,
- help="whether to split text using <space>",
- )
- group.add_argument(
- "--input_size",
- type=int_or_none,
- default=None,
- help="The number of dimensions for input features.",
- )
- group.add_argument(
- "--init",
- type=str_or_none,
- default=None,
- help="Type of model initialization to use.",
- )
- group.add_argument(
- "--model_conf",
- action=NestedDictAction,
- default=get_default_kwargs(TransducerModel),
- help="The keyword arguments for the model class.",
- )
- # group.add_argument(
- # "--encoder_conf",
- # action=NestedDictAction,
- # default={},
- # help="The keyword arguments for the encoder class.",
- # )
- group.add_argument(
- "--joint_network_conf",
- action=NestedDictAction,
- default={},
- help="The keyword arguments for the joint network class.",
- )
- group = parser.add_argument_group(description="Preprocess related.")
- group.add_argument(
- "--use_preprocessor",
- type=str2bool,
- default=True,
- help="Whether to apply preprocessing to input data.",
- )
- group.add_argument(
- "--token_type",
- type=str,
- default="bpe",
- choices=["bpe", "char", "word", "phn"],
- help="The type of tokens to use during tokenization.",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The path of the sentencepiece model.",
- )
- parser.add_argument(
- "--non_linguistic_symbols",
- type=str_or_none,
- help="The 'non_linguistic_symbols' file path.",
- )
- parser.add_argument(
- "--cleaner",
- type=str_or_none,
- choices=[None, "tacotron", "jaconv", "vietnamese"],
- default=None,
- help="Text cleaner to use.",
- )
- parser.add_argument(
- "--g2p",
- type=str_or_none,
- choices=g2p_choices,
- default=None,
- help="g2p method to use if --token_type=phn.",
- )
- parser.add_argument(
- "--speech_volume_normalize",
- type=float_or_none,
- default=None,
- help="Normalization value for maximum amplitude scaling.",
- )
- parser.add_argument(
- "--rir_scp",
- type=str_or_none,
- default=None,
- help="The RIR SCP file path.",
- )
- parser.add_argument(
- "--rir_apply_prob",
- type=float,
- default=1.0,
- help="The probability of the applied RIR convolution.",
- )
- parser.add_argument(
- "--noise_scp",
- type=str_or_none,
- default=None,
- help="The path of noise SCP file.",
- )
- parser.add_argument(
- "--noise_apply_prob",
- type=float,
- default=1.0,
- help="The probability of the applied noise addition.",
- )
- parser.add_argument(
- "--noise_db_range",
- type=str,
- default="13_15",
- help="The range of the noise decibel level.",
- )
- for class_choices in cls.class_choices_list:
- # Append --<name> and --<name>_conf.
- # e.g. --decoder and --decoder_conf
- class_choices.add_arguments(group)
-
- @classmethod
- def build_collate_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Callable[
- [Collection[Tuple[str, Dict[str, np.ndarray]]]],
- Tuple[List[str], Dict[str, torch.Tensor]],
- ]:
- """Build collate function.
- Args:
- cls: ASRTransducerTask object.
- args: Task arguments.
- train: Training mode.
- Return:
- : Callable collate function.
- """
- assert check_argument_types()
-
- return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
-
- @classmethod
- def build_preprocess_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- """Build pre-processing function.
- Args:
- cls: ASRTransducerTask object.
- args: Task arguments.
- train: Training mode.
- Return:
- : Callable pre-processing function.
- """
- assert check_argument_types()
-
- if args.use_preprocessor:
- retval = CommonPreprocessor(
- train=train,
- token_type=args.token_type,
- token_list=args.token_list,
- bpemodel=args.bpemodel,
- non_linguistic_symbols=args.non_linguistic_symbols,
- text_cleaner=args.cleaner,
- g2p_type=args.g2p,
- split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
- rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
- rir_apply_prob=args.rir_apply_prob
- if hasattr(args, "rir_apply_prob")
- else 1.0,
- noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
- noise_apply_prob=args.noise_apply_prob
- if hasattr(args, "noise_apply_prob")
- else 1.0,
- noise_db_range=args.noise_db_range
- if hasattr(args, "noise_db_range")
- else "13_15",
- speech_volume_normalize=args.speech_volume_normalize
- if hasattr(args, "rir_scp")
- else None,
- )
- else:
- retval = None
-
- assert check_return_type(retval)
- return retval
-
- @classmethod
- def required_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Required data depending on task mode.
- Args:
- cls: ASRTransducerTask object.
- train: Training mode.
- inference: Inference mode.
- Return:
- retval: Required task data.
- """
- if not inference:
- retval = ("speech", "text")
- else:
- retval = ("speech",)
-
- return retval
-
- @classmethod
- def optional_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Optional data depending on task mode.
- Args:
- cls: ASRTransducerTask object.
- train: Training mode.
- inference: Inference mode.
- Return:
- retval: Optional task data.
- """
- retval = ()
- assert check_return_type(retval)
-
- return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> TransducerModel:
@@ -1665,7 +1455,7 @@
decoder_output_size = decoder.output_size
if getattr(args, "decoder", None) is not None:
- att_decoder_class = decoder_choices.get_class(args.att_decoder)
+ att_decoder_class = decoder_choices.get_class(args.decoder)
att_decoder = att_decoder_class(
vocab_size=vocab_size,
@@ -1683,35 +1473,23 @@
)
# 7. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
- if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
- model = UnifiedTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- else:
- model = TransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
# 8. Initialize model
if args.init is not None:
raise NotImplementedError(
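The hard-coded `UnifiedTransducerModel`/`TransducerModel` branch is replaced by a lookup through `model_choices`, falling back to the `asr` entry when the config predates the `--model` flag. A minimal sketch of that fallback pattern, with a hypothetical string registry standing in for `ClassChoices`:

```python
import argparse

registry = {"asr": "ASRModel", "rnnt": "TransducerModel"}

def get_model_class(args: argparse.Namespace) -> str:
    # Old configs may not define args.model at all; attribute access then
    # raises AttributeError and we fall back to the default "asr" entry.
    try:
        return registry[args.model]
    except AttributeError:
        return registry["asr"]

print(get_model_class(argparse.Namespace(model="rnnt")))  # TransducerModel
print(get_model_class(argparse.Namespace()))              # ASRModel
```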
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e4ce7..2625fec 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -58,7 +58,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -114,7 +114,7 @@
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="sond",
)
encoder_choices = ClassChoices(
@@ -544,9 +544,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -902,9 +902,9 @@
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 80d66d5..44fdf8e 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -14,10 +14,10 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.lm.abs_model import AbsLM
-from funasr.lm.abs_model import LanguageModel
-from funasr.lm.seq_rnn_lm import SequentialRNNLM
-from funasr.lm.transformer_lm import TransformerLM
+from funasr.train.abs_model import AbsLM
+from funasr.train.abs_model import LanguageModel
+from funasr.models.seq_rnn_lm import SequentialRNNLM
+from funasr.models.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
@@ -206,6 +206,4 @@
# 3. Initialize
if args.init is not None:
initialize(model, args.init)
-
- assert check_return_type(model)
return model
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 0170f28..a63bbe4 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
punc_choices = ClassChoices(
"punctuation",
classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
- type_check=AbsPunctuation,
default="target_delay",
)
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 7cfcbd0..4769758 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -70,11 +70,11 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.models.base_model import FunASRModel
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -129,7 +129,7 @@
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="asr",
)
preencoder_choices = ClassChoices(
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index 9710447..e4815da 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -25,7 +25,7 @@
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_asr import ESPnetASRModel
+from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
@@ -49,7 +49,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -94,7 +94,7 @@
classes=dict(
espnet=ESPnetSVModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="espnet",
)
preencoder_choices = ClassChoices(
@@ -488,9 +488,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index d07acf1..ec95596 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -1,77 +1,42 @@
import argparse
import logging
+import os
+from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
-import os
-from pathlib import Path
-from typing import Tuple
from typing import Union
-import yaml
+
import numpy as np
import torch
+import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
-from funasr.modules.subsampling import Conv1dSubsampling
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
@@ -292,7 +257,7 @@
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
-
+
# 1. frontend
if args.input_size is None:
# Extract features in the model
@@ -308,7 +273,7 @@
args.frontend_conf = {}
frontend = None
input_size = args.input_size
-
+
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@@ -344,7 +309,7 @@
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
- #if cmvn_file is not None:
+ # if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2..0000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
- """The common abstract class among each tasks
-
- "ESPnetModel" is referred to a class which inherits torch.nn.Module,
- and makes the dnn-models forward as its member field,
- a.k.a delegate pattern,
- and defines "loss", "stats", and "weight" for the task.
-
- If you intend to implement new task in ESPNet,
- the model must inherit this class.
- In other words, the "mediator" objects between
- our training system and the your task class are
- just only these three values, loss, stats, and weight.
-
- Example:
- >>> from funasr.tasks.abs_task import AbsTask
- >>> class YourESPnetModel(AbsESPnetModel):
- ... def forward(self, input, input_lengths):
- ... ...
- ... return loss, stats, weight
- >>> class YourTask(AbsTask):
- ... @classmethod
- ... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
- """
-
- def __init__(self):
- super().__init__()
- self.num_updates = 0
-
- @abstractmethod
- def forward(
- self, **batch: torch.Tensor
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
- raise NotImplementedError
-
- def set_num_updates(self, num_updates):
- self.num_updates = num_updates
-
- def get_num_updates(self):
- return self.num_updates
diff --git a/funasr/train/abs_model.py b/funasr/train/abs_model.py
index 1c7ff3d..8d684be 100644
--- a/funasr/train/abs_model.py
+++ b/funasr/train/abs_model.py
@@ -1,7 +1,7 @@
from abc import ABC
from abc import abstractmethod
-
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -12,13 +12,10 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-
-
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
- """The abstract class
+class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
+ """The abstract LM class
To share the loss calculation across different models,
we use the delegate pattern here:
@@ -29,17 +26,134 @@
"""
@abstractmethod
- def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def with_vad(self) -> bool:
+ def forward(
+ self, input: torch.Tensor, hidden: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
-class PunctuationModel(AbsESPnetModel):
+class LanguageModel(FunASRModel):
+ def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
+ assert check_argument_types()
+ super().__init__()
+ self.lm = lm
+ self.sos = 1
+ self.eos = 2
+
+ # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
+ self.ignore_id = ignore_id
+
+ def nll(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ max_length: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute negative log likelihood(nll)
+
+ Normally, this function is called in batchify_nll.
+ Args:
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ max_length: int
+ """
+ batch_size = text.size(0)
+ # For data parallel
+ if max_length is None:
+ text = text[:, : text_lengths.max()]
+ else:
+ text = text[:, :max_length]
+
+ # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
+ # text: (Batch, Length) -> x, t: (Batch, Length + 1)
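+ # e.g. text = [w1, w2, w3] -> x = [<sos>, w1, w2, w3], t = [w1, w2, w3, <eos>]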
+ x = F.pad(text, [1, 0], "constant", self.sos)
+ t = F.pad(text, [0, 1], "constant", self.ignore_id)
+ for i, l in enumerate(text_lengths):
+ t[i, l] = self.eos
+ x_lengths = text_lengths + 1
+
+ # 2. Forward Language model
+ # x: (Batch, Length) -> y: (Batch, Length, NVocab)
+ y, _ = self.lm(x, None)
+
+ # 3. Calc negative log likelihood
+ # nll: (BxL,)
+ nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+ # zero out the nll at padded positions
+ if max_length is None:
+ nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
+ else:
+ nll.masked_fill_(
+ make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+ 0.0,
+ )
+ # nll: (BxL,) -> (B, L)
+ nll = nll.view(batch_size, -1)
+ return nll, x_lengths
+
+ def batchify_nll(
+ self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute negative log likelihood(nll) from transformer language model
+
+ To avoid OOM, this fuction seperate the input into batches.
+ Then call nll for each batch and combine and return results.
+ Args:
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ batch_size: int, samples each batch contain when computing nll,
+ you may change this to avoid OOM or increase
+
+ """
+ total_num = text.size(0)
+ if total_num <= batch_size:
+ nll, x_lengths = self.nll(text, text_lengths)
+ else:
+ nlls = []
+ x_lengths = []
+ max_length = text_lengths.max()
+
+ start_idx = 0
+ while True:
+ end_idx = min(start_idx + batch_size, total_num)
+ batch_text = text[start_idx:end_idx, :]
+ batch_text_lengths = text_lengths[start_idx:end_idx]
+ # batch_nll: (B, L)
+ batch_nll, batch_x_lengths = self.nll(
+ batch_text, batch_text_lengths, max_length=max_length
+ )
+ nlls.append(batch_nll)
+ x_lengths.append(batch_x_lengths)
+ start_idx = end_idx
+ if start_idx == total_num:
+ break
+ nll = torch.cat(nlls)
+ x_lengths = torch.cat(x_lengths)
+ assert nll.size(0) == total_num
+ assert x_lengths.size(0) == total_num
+ return nll, x_lengths
+
+ def forward(
+ self, text: torch.Tensor, text_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ nll, y_lengths = self.nll(text, text_lengths)
+ ntokens = y_lengths.sum()
+ loss = nll.sum() / ntokens
+ stats = dict(loss=loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self, text: torch.Tensor, text_lengths: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ return {}
+
+
+class PunctuationModel(FunASRModel):
- def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
+ def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
self.punc_model = punc_model
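
Note: a minimal usage sketch of the new LanguageModel wrapper (the AbsLM
subclass MyTransformerLM and the toy tensors below are hypothetical, not
part of this patch):

    import torch
    from funasr.train.abs_model import LanguageModel

    lm = MyTransformerLM(vocab_size=5000)          # any AbsLM implementation (hypothetical)
    model = LanguageModel(lm=lm, vocab_size=5000)
    text = torch.randint(3, 5000, (4, 20))         # (Batch, Length) token ids
    text_lengths = torch.full((4,), 20, dtype=torch.long)
    loss, stats, weight = model(text, text_lengths)            # mean nll per token
    nll, x_lengths = model.batchify_nll(text, text_lengths, batch_size=2)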
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index a40f031..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -39,7 +39,7 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
@@ -166,7 +166,7 @@
@classmethod
def run(
cls,
- model: AbsESPnetModel,
+ model: FunASRModel,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory,
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
new file mode 100644
index 0000000..36795b4
--- /dev/null
+++ b/funasr/utils/prepare_data.py
@@ -0,0 +1,226 @@
+import logging
+import os
+import shutil
+from multiprocessing import Pool
+
+import kaldiio
+import numpy as np
+import torch.distributed as dist
+import torchaudio
+
+
+def filter_wav_text(data_dir, dataset):
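+ """Keep only the samples that appear in both wav.scp and text.
+
+ The original files are backed up as wav.scp.bak and text.bak, then
+ rewritten with the matched entries only.
+ """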
+ wav_file = os.path.join(data_dir, dataset, "wav.scp")
+ text_file = os.path.join(data_dir, dataset, "text")
+ with open(wav_file) as f_wav, open(text_file) as f_text:
+ wav_lines = f_wav.readlines()
+ text_lines = f_text.readlines()
+ os.rename(wav_file, "{}.bak".format(wav_file))
+ os.rename(text_file, "{}.bak".format(text_file))
+ wav_dict = {}
+ for line in wav_lines:
+ parts = line.strip().split()
+ if len(parts) < 2:
+ continue
+ wav_dict[parts[0]] = parts[1]
+ text_dict = {}
+ for line in text_lines:
+ parts = line.strip().split()
+ if len(parts) < 2:
+ continue
+ text_dict[parts[0]] = " ".join(parts[1:])
+ filter_count = 0
+ with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
+ for sample_name, wav_path in wav_dict.items():
+ if sample_name in text_dict:
+ f_wav.write(sample_name + " " + wav_path + "\n")
+ f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
+ else:
+ filter_count += 1
+ logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
+ format(filter_count, len(wav_lines), dataset))
+
+
+def wav2num_frame(wav_path, frontend_conf):
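+ """Estimate the frame count and feature dimension of a wav file.
+
+ The frame count follows from the waveform duration, the frontend frame
+ shift (ms) and the LFR subsampling factor lfr_n; the feature dimension
+ is n_mels stacked over lfr_m frames.
+ """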
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
+ feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
+ return n_frames, feature_dim
+
+
+def calc_shape_core(root_path, args, idx):
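+ """Worker for calc_shape: write the shape file for split idx of the scp file.
+
+ Each output line is "<sample_name> <n_frames>,<feature_dim>" for sound and
+ kaldi_ark inputs, or "<sample_name> <n_tokens>" for text inputs.
+ """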
+ file_name = args.data_file_names.split(",")[0]
+ data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
+ scp_file = os.path.join(root_path, "{}.{}".format(file_name, idx))
+ shape_file = os.path.join(root_path, "{}_shape.{}".format(data_name, idx))
+ with open(scp_file) as f:
+ lines = f.readlines()
+ data_type = args.dataset_conf.get("data_types", "sound,text").split(",")[0]
+ if data_type == "sound":
+ frontend_conf = args.frontend_conf
+ dataset_conf = args.dataset_conf
+ length_min = dataset_conf.get("{}_length_min".format(data_name), -1)
+ length_max = dataset_conf.get("{}_length_max".format(data_name), -1)
+ with open(shape_file, "w") as f:
+ for line in lines:
+ sample_name, wav_path = line.strip().split()
+ n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
+ write_flag = True
+ if n_frames > 0 and length_min > 0:
+ write_flag = write_flag and n_frames >= length_min
+ if n_frames > 0 and length_max > 0:
+ write_flag = write_flag and n_frames <= length_max
+ if write_flag:
+ f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
+ f.flush()
+ elif data_type == "kaldi_ark":
+ dataset_conf = args.dataset_conf
+ length_min = dataset_conf.get("{}_length_min".format(data_name), -1)
+ length_max = dataset_conf.get("{}_length_max".format(data_name), -1)
+ with open(shape_file, "w") as f:
+ for line in lines:
+ sample_name, feature_path = line.strip().split()
+ feature = kaldiio.load_mat(feature_path)
+ n_frames, feature_dim = feature.shape
+ write_flag = True
+ if n_frames > 0 and length_min > 0:
+ write_flag = write_flag and n_frames >= length_min
+ if n_frames > 0 and length_max > 0:
+ write_flag = write_flag and n_frames <= length_max
+ if write_flag:
+ f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
+ f.flush()
+ elif data_type == "text":
+ with open(shape_file, "w") as f:
+ for line in lines:
+ sample_name, text = line.strip().split(maxsplit=1)
+ n_tokens = len(text.split())
+ f.write("{} {}\n".format(sample_name, str(int(np.ceil(n_tokens)))))
+ f.flush()
+ else:
+ raise RuntimeError("Unsupported data_type: {}".format(data_type))
+
+
+def calc_shape(args, dataset, nj=64):
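+ """Generate the shape file for the small dataset type with nj parallel jobs.
+
+ The first data file is split into nj pieces, calc_shape_core runs on each
+ piece in a multiprocessing pool, and the per-job shape files are merged
+ into <data-name>_shape.
+ """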
+ data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
+ shape_path = os.path.join(args.data_dir, dataset, "{}_shape".format(data_name))
+ if os.path.exists(shape_path):
+ logging.info('Shape file for small dataset already exists.')
+ return
+
+ split_shape_path = os.path.join(args.data_dir, dataset, "{}_shape_files".format(data_name))
+ if os.path.exists(split_shape_path):
+ shutil.rmtree(split_shape_path)
+ os.mkdir(split_shape_path)
+
+ # split
+ file_name = args.data_file_names.split(",")[0]
+ scp_file = os.path.join(args.data_dir, dataset, file_name)
+ with open(scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_shape_path, "{}.{}".format(file_name, str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
+ logging.info("Generating shape files, please wait a few minutes...")
+ p.close()
+ p.join()
+
+ # combine
+ with open(shape_path, "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(split_shape_path, "{}_shape.{}".format(data_name, str(i + 1)))
+ with open(job_file) as job_f:
+ lines = job_f.readlines()
+ f.writelines(lines)
+ logging.info('Generating shape files done.')
+
+
+def generate_data_list(args, data_dir, dataset, nj=64):
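+ """Split the data files into nj pieces for the large dataset type.
+
+ Each piece is written to split/<job-id>/ and the list file
+ <data-names>_data.list records, one line per piece, the paths of its
+ data files.
+ """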
+ data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
+ file_names = args.data_file_names.split(",")
+ concat_data_name = "_".join(data_names)
+ list_file = os.path.join(data_dir, dataset, "{}_data.list".format(concat_data_name))
+ if os.path.exists(list_file):
+ logging.info('Data list for large dataset already exists.')
+ return
+ split_path = os.path.join(data_dir, dataset, "split")
+ if os.path.exists(split_path):
+ shutil.rmtree(split_path)
+ os.mkdir(split_path)
+
+ data_lines_list = []
+ for file_name in file_names:
+ with open(os.path.join(data_dir, dataset, file_name)) as f:
+ lines = f.readlines()
+ data_lines_list.append(lines)
+ num_lines = len(data_lines_list[0])
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ split_path_nj = os.path.join(split_path, str(i + 1))
+ os.mkdir(split_path_nj)
+ for file_id, file_name in enumerate(file_names):
+ file = os.path.join(split_path_nj, file_name)
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(data_lines_list[file_id][start:])
+ else:
+ f.writelines(data_lines_list[file_id][start:end])
+ start = end
+
+ with open(list_file, "w") as f_data:
+ for i in range(nj):
+ path = ""
+ for file_name in file_names:
+ path = path + os.path.join(split_path, str(i + 1), file_name)
+ f_data.write(path + "\n")
+
+
+def prepare_data(args, distributed_option):
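+ """Entry point of data preparation before training.
+
+ Optionally filters mismatched wav/text pairs, generates shape files
+ (small dataset type) or data lists (large dataset type), and registers
+ the resulting paths in args. File generation runs on rank 0 only;
+ distributed workers synchronize at the final barrier.
+ """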
+ distributed = distributed_option.distributed
+ if not distributed or distributed_option.dist_rank == 0:
+ if hasattr(args, "filter_input") and args.filter_input:
+ filter_wav_text(args.data_dir, args.train_set)
+ filter_wav_text(args.data_dir, args.valid_set)
+
+ if args.dataset_type == "small":
+ calc_shape(args, args.train_set)
+ calc_shape(args, args.valid_set)
+
+ if args.dataset_type == "large":
+ generate_data_list(args, args.data_dir, args.train_set)
+ generate_data_list(args, args.data_dir, args.valid_set)
+
+ data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
+ data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
+ file_names = args.data_file_names.split(",")
+ print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names))
+ assert len(data_names) == len(data_types) == len(file_names)
+ if args.dataset_type == "small":
+ args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "{}_shape".format(data_names[0]))]
+ args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "{}_shape".format(data_names[0]))]
+ args.train_data_path_and_name_and_type, args.valid_data_path_and_name_and_type = [], []
+ for file_name, data_name, data_type in zip(file_names, data_names, data_types):
+ args.train_data_path_and_name_and_type.append(
+ ["{}/{}/{}".format(args.data_dir, args.train_set, file_name), data_name, data_type])
+ args.valid_data_path_and_name_and_type.append(
+ ["{}/{}/{}".format(args.data_dir, args.valid_set, file_name), data_name, data_type])
+ else:
+ concat_data_name = "_".join(data_names)
+ args.train_data_file = os.path.join(args.data_dir, args.train_set, "{}_data.list".format(concat_data_name))
+ args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "{}_data.list".format(concat_data_name))
+ if distributed:
+ dist.barrier()
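
Note: for reference, a sketch of the Kaldi-style inputs this script consumes
and the shape file it writes (utterance ids, paths, and numbers are made up):

    # data/train/wav.scp: <utt-id> <wav-path>
    utt001 /corpus/train/utt001.wav
    utt002 /corpus/train/utt002.wav

    # data/train/text: <utt-id> <transcript>
    utt001 hello world
    utt002 good morning

    # data/train/speech_shape (small dataset type): <utt-id> <n_frames>,<dim>
    utt001 418,560
    utt002 375,560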
diff --git a/funasr/version.txt b/funasr/version.txt
index 4b9fcbe..cb0c939 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.5.1
+0.5.2
--
Gitblit v1.9.1