From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update
---
examples/industrial_data_pretraining/paraformer/README.md | 424 +++++++++++++++++++++
examples/industrial_data_pretraining/paraformer-zh-spk/README_zh.md | 18
funasr/train_utils/trainer.py | 82 ++-
examples/README.md | 424 +++++++++++++++++++++
funasr/bin/train.py | 2
funasr/train_utils/average_nbest_models.py | 137 -----
docs/tutorial/README_zh.md | 18
examples/industrial_data_pretraining/paraformer/README_zh.md | 18
examples/industrial_data_pretraining/paraformer_streaming/README_zh.md | 18
examples/README_zh.md | 18
README_zh.md | 4
README.md | 4
12 files changed, 979 insertions(+), 188 deletions(-)
diff --git a/README.md b/README.md
index 2409fe5..9c3e00d 100644
--- a/README.md
+++ b/README.md
@@ -208,8 +208,8 @@
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
-
-More examples ref to [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
+For more usage details, see the [docs](docs/tutorial/README_zh.md);
+for more examples, see the [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
## Export ONNX
diff --git a/README_zh.md b/README_zh.md
index f32e76a..65029a1 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -208,8 +208,8 @@
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
-More details ([usage](docs/tutorial/README_zh.md)),
-more ([examples](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
+More details ([tutorial docs](docs/tutorial/README_zh.md)),
+more ([model examples](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
## Export ONNX
### Export from the command line
diff --git a/docs/tutorial/README_zh.md b/docs/tutorial/README_zh.md
index fa85290..4e9bb3f 100644
--- a/docs/tutorial/README_zh.md
+++ b/docs/tutorial/README_zh.md
@@ -235,13 +235,17 @@
- `valid_data_set_list` (str): Path to the validation data, jsonl format by default; see the ([examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
-- `train_conf.max_epoch` (int): Total number of training epochs.
-- `train_conf.log_interval` (int): Number of steps between log prints.
-- `train_conf.resume` (int): Whether to enable resuming from a checkpoint.
-- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
-- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
-- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
-- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `train_conf.max_epoch` (int): `100` (default), total number of training epochs.
+- `train_conf.log_interval` (int): `50` (default), number of steps between log prints.
+- `train_conf.resume` (bool): `True` (default), whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): `5000` (default), interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): `5000` (default), interval in steps for saving checkpoints during training.
+- `train_conf.avg_keep_nbest_models_type` (str): `acc` (default), keep the n-best checkpoints by validation accuracy (higher is better); `loss` keeps the n-best by validation loss (lower is better).
+- `train_conf.keep_nbest_models` (int): `500` (default), maximum number of checkpoints to keep; together with `avg_keep_nbest_models_type`, the best n checkpoints by validation acc/loss are kept and the rest are deleted to save storage space.
+- `train_conf.avg_nbest_model` (int): `5` (default), average the best n checkpoints, ranked by validation acc/loss according to `avg_keep_nbest_models_type`.
+- `train_conf.accum_grad` (int): `1` (default), gradient accumulation steps.
+- `train_conf.grad_clip` (float): `10.0` (default), gradient clipping threshold.
+- `train_conf.use_fp16` (bool): `False` (default), enable fp16 training to speed up training.
- `optim_conf.lr` (float): The learning rate.
- `output_dir` (str): Path for saving the model.
- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..20102cc
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,424 @@
+([简体中文](./README_zh.md)|English)
+
+FunASR has open-sourced a large number of pre-trained models on industrial data. You are free to use, copy, modify, and share FunASR models under the [Model License Agreement](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE). Below, we list some representative models. For a comprehensive list, please refer to our [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo).
+
+<div align="center">
+<h4>
+ <a href="#Inference"> Model Inference </a>
+｜<a href="#Training"> Model Training and Testing </a>
+｜<a href="#Export"> Model Export and Testing </a>
+</h4>
+</div>
+
+<a name="Inference"></a>
+## Model Inference
+
+### Quick Start
+
+For command-line invocation:
+```shell
+funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++input=asr_example_zh.wav
+```
+
+For python code invocation (recommended):
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh")
+
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
+print(res)
+```
+
+### API Description
+#### AutoModel Definition
+```python
+model = AutoModel(model=[str], device=[str], ncpu=[int], output_dir=[str], batch_size=[int], hub=[str], **kwargs)
+```
+- `model`(str): model name in the [Model Repository](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo), or a model path on local disk.
+- `device`(str): `cuda:0` (default gpu0) for using GPU for inference, specify `cpu` for using CPU.
+- `ncpu`(int): `4` (default), sets the number of threads for CPU internal operations.
+- `output_dir`(str): `None` (default), set this to specify the output path for the results.
+- `batch_size`(int): `1` (default), the number of samples per batch during decoding.
+- `hub`(str): `ms` (default), download models from ModelScope; use `hf` to download models from Hugging Face.
+- `**kwargs`(dict): Any parameter in config.yaml can be specified directly here, for instance the maximum segmentation length of the VAD model, `max_single_segment_time=6000` (milliseconds); see the sketch below.
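+
+As a sketch of this (the value is illustrative; `max_single_segment_time` is simply a `config.yaml` key of the fsmn-vad model, as noted above):
+
+```python
+from funasr import AutoModel
+
+# max_single_segment_time lives in the vad model's config.yaml;
+# any other config.yaml key can be overridden the same way.
+model = AutoModel(model="fsmn-vad", max_single_segment_time=6000)
+```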
+
+#### AutoModel Inference
+```python
+res = model.generate(input=[str], output_dir=[str])
+```
+- `input`: The input to be decoded, which could be:
+ - A wav file path, e.g., asr_example.wav
+ - A pcm file path, e.g., asr_example.pcm, in this case, specify the audio sampling rate fs (default is 16000)
+ - An audio byte stream, e.g., byte data from a microphone
+ - A wav.scp, a Kaldi-style wav list (wav_id \t wav_path), for example:
+ ```text
+ asr_example1 ./audios/asr_example1.wav
+ asr_example2 ./audios/asr_example2.wav
+ ```
+ When using wav.scp as input, you must set output_dir to save the output results.
+ - Audio samples, e.g. `audio, rate = soundfile.read("asr_example_zh.wav")`, with data type numpy.ndarray. Batch input is supported as a list (see the sketch after this list):
+ ```[audio_sample1, audio_sample2, ..., audio_sampleN]```
+ - fbank input, supports batch grouping. Shape is [batch, frames, dim], type is torch.Tensor.
+- `output_dir`: None (default), if set, specifies the output path for the results.
+- `**kwargs`(dict): Inference parameters related to the model, for example, `beam_size=10`, `decoding_ctc_weight=0.1`.
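+
+For instance, a minimal sketch of batch input with numpy.ndarray samples (the two wav paths are placeholders for your own files):
+
+```python
+import soundfile
+
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh", batch_size=2)
+
+# each element is a numpy.ndarray of audio samples; the list forms one batch
+audio1, rate = soundfile.read("asr_example1.wav")
+audio2, rate = soundfile.read("asr_example2.wav")
+res = model.generate(input=[audio1, audio2])
+print(res)
+```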
+
+
+### More Usage Examples
+
+
+#### Speech Recognition (Non-streaming)
+```python
+from funasr import AutoModel
+# paraformer-zh is a multi-functional asr model
+# use vad, punc, spk or not as you need
+model = AutoModel(model="paraformer-zh",
+ vad_model="fsmn-vad",
+ vad_kwargs={"max_single_segment_time": 60000},
+ punc_model="ct-punc",
+ # spk_model="cam++"
+ )
+wav_file = f"{model.model_path}/example/asr_example.wav"
+res = model.generate(input=wav_file, batch_size_s=300, batch_size_threshold_s=60, hotword='魔搭')
+print(res)
+```
+Notes:
+- Typically, model input duration is limited to under 30 seconds. Combined with `vad_model`, however, audio input of any length is supported; this is not limited to paraformer, and any audio model can be used.
+- Parameters related to model can be directly specified in the definition of AutoModel; parameters related to `vad_model` can be set through `vad_kwargs`, which is a dict; similar parameters include `punc_kwargs` and `spk_kwargs`.
+- `max_single_segment_time`: Denotes the maximum audio segmentation length for `vad_model`, measured in milliseconds (ms).
+- `batch_size_s` represents the use of dynamic batching, where the total audio duration within a batch is measured in seconds (s).
+- `batch_size_threshold_s`: Indicates that when the duration of an audio segment post-VAD segmentation exceeds the batch_size_threshold_s threshold, the batch size is set to 1, measured in seconds (s).
+
+Recommendations:
+
+When you input long audio and encounter Out Of Memory (OOM) issues, since memory usage tends to increase quadratically with audio length, consider the following three scenarios:
+
+a) At the beginning of inference, memory usage depends mainly on `batch_size_s`; reducing it lowers memory usage.
+b) In the middle of inference, if long VAD-cut segments trigger OOM even though the total duration per batch is below `batch_size_s`, reduce `batch_size_threshold_s`; any segment exceeding this threshold is decoded with batch size 1.
+c) Towards the end of inference, if a long VAD-cut segment already exceeds `batch_size_threshold_s` (so it is decoded with batch size 1) and OOM persists, reduce `max_single_segment_time` to shorten the VAD segment length.
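+
+Putting the three knobs together, a hedged starting point for long-audio inference under tight GPU memory (the values and the input path are illustrative, not tuned recommendations):
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh",
+                  vad_model="fsmn-vad",
+                  # (c) shorter VAD segments bound per-segment memory
+                  vad_kwargs={"max_single_segment_time": 30000})
+res = model.generate(input="long_audio.wav",
+                     batch_size_s=150,           # (a) less total audio per batch
+                     batch_size_threshold_s=30)  # (b) force batch size 1 sooner
+```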
+
+#### Speech Recognition (Streaming)
+```python
+from funasr import AutoModel
+
+chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
+decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
+
+model = AutoModel(model="paraformer-zh-streaming")
+
+import soundfile
+import os
+
+wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+speech, sample_rate = soundfile.read(wav_file)
+chunk_stride = chunk_size[1] * 960 # 600ms
+
+cache = {}
+total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+for i in range(total_chunk_num):
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+ is_final = i == total_chunk_num - 1
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+ print(res)
+```
+Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` means the real-time output granularity is `10*60=600ms`, with `5*60=300ms` of lookahead. Each inference call consumes `600ms` of input (`16000*0.6=9600` sample points) and outputs the corresponding text. For the last speech segment, set `is_final=True` to force output of the last word.
+
+#### Voice Activity Detection (Non-Streaming)
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="fsmn-vad")
+wav_file = f"{model.model_path}/example/asr_example.wav"
+res = model.generate(input=wav_file)
+print(res)
+```
+Note: The output format of the VAD model is: `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` indicates the starting/ending point of the `N-th` valid audio segment, measured in milliseconds.
+
+#### Voice Activity Detection (Streaming)
+```python
+from funasr import AutoModel
+
+chunk_size = 200 # ms
+model = AutoModel(model="fsmn-vad")
+
+import soundfile
+
+wav_file = f"{model.model_path}/example/vad_example.wav"
+speech, sample_rate = soundfile.read(wav_file)
+chunk_stride = int(chunk_size * sample_rate / 1000)
+
+cache = {}
+total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+for i in range(total_chunk_num):
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+ is_final = i == total_chunk_num - 1
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+ if len(res[0]["value"]):
+ print(res)
+```
+Note: The output format for the streaming VAD model can be one of four scenarios:
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`: the same as the offline VAD output above.
+- `[[beg, -1]]`: only a starting point has been detected.
+- `[[-1, end]]`: only an ending point has been detected.
+- `[]`: neither a starting point nor an ending point has been detected.
+
+The output is measured in milliseconds and represents the absolute time from the starting point.
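+
+For example, a small helper sketch (hypothetical, not part of the FunASR API) that stitches these chunk-level outputs back into complete `[beg, end]` segments:
+
+```python
+def stitch_vad_segments(chunk_values):
+    """chunk_values: the res[0]["value"] lists of all chunks, concatenated."""
+    segments, pending_beg = [], None
+    for beg, end in chunk_values:
+        if beg != -1 and end != -1:    # complete segment within one chunk
+            segments.append([beg, end])
+        elif beg != -1:                # only a starting point so far
+            pending_beg = beg
+        elif pending_beg is not None:  # only an ending point; close the segment
+            segments.append([pending_beg, end])
+            pending_beg = None
+    return segments
+```
+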
+#### Punctuation Restoration
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="ct-punc")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
+print(res)
+```
+#### Timestamp Prediction
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="fa-zh")
+wav_file = f"{model.model_path}/example/asr_example.wav"
+text_file = f"{model.model_path}/example/text.txt"
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
+print(res)
+```
+
+For more examples, see the [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
+
+<a name="Training"></a>
+## Model Training and Testing
+
+### Quick Start
+
+Execute via command line (for quick testing, not recommended):
+```shell
+funasr-train ++model=paraformer-zh ++train_data_set_list=data/list/train.jsonl ++valid_data_set_list=data/list/val.jsonl ++output_dir="./outputs" &> log.txt &
+```
+
+Execute with Python code (supports multi-node and multi-GPU, recommended):
+
+```shell
+cd examples/industrial_data_pretraining/paraformer
+bash finetune.sh
+# "log_file: ./outputs/log.txt"
+```
+For the full code, see [finetune.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/examples/industrial_data_pretraining/paraformer/finetune.sh)
+
+### Detailed Parameter Description:
+
+```shell
+funasr/bin/train.py \
+++model="${model_name_or_model_dir}" \
+++train_data_set_list="${train_data}" \
+++valid_data_set_list="${val_data}" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=1 \
+++train_conf.resume=false \
+++train_conf.validate_interval=2000 \
+++train_conf.save_checkpoint_interval=2000 \
+++train_conf.keep_nbest_models=20 \
+++train_conf.avg_nbest_model=5 \
+++optim_conf.lr=0.0002 \
+++output_dir="${output_dir}" &> ${log_file}
+```
+
+- `model` (str): The model name (its ID in the model repository), in which case the model is downloaded automatically to local storage; alternatively, the path to a model already downloaded locally.
+- `train_data_set_list` (str): Path to the training data, typically in jsonl format; see the [examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list) for details.
+- `valid_data_set_list` (str): Path to the validation data, also typically in jsonl format; see the [examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list) for details.
+- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
+- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
+- `train_conf.max_epoch` (int): Total number of training epochs.
+- `train_conf.log_interval` (int): Number of steps between log prints.
+- `train_conf.resume` (bool): Whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
+- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
+- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `optim_conf.lr` (float): The learning rate.
+- `output_dir` (str): Path for saving the model.
+- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
+
+#### Multi-GPU Training
+##### Single-Machine Multi-GPU Training
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+../../../funasr/bin/train.py ${train_args}
+```
+`--nnodes` specifies the total number of participating nodes, while `--nproc_per_node` specifies the number of processes running on each node.
+
+##### Multi-Machine Multi-GPU Training
+
+On the master node, assuming the IP is 192.168.1.1 and the port is 12345, and you're using 2 GPUs, you would run the following command:
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+../../../funasr/bin/train.py ${train_args}
+```
+On the worker node (assuming the IP is 192.168.1.2), make sure the master address and port match those of the master node, then run the same command with `--node_rank` set to 1:
+
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+../../../funasr/bin/train.py ${train_args}
+```
+
+`--nnodes` indicates the total number of nodes participating in training, `--node_rank` is the rank of the current node, and `--nproc_per_node` specifies the number of processes per node (usually the number of GPUs).
+
+#### Data Preparation
+
+For the `jsonl` format, see the ([demo](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
+The `scp2jsonl` command can generate it from a `wav.scp` and a `text.txt`, prepared as follows:
+
+`train_text.txt`
+
+```bash
+ID0012W0013 当客户风险承受能力评估依据发生变化时
+ID0012W0014 所有只要处理 data 不管你是做 machine learning 做 deep learning
+ID0012W0015 he tried to think how it could be
+```
+
+
+`train_wav.scp`
+
+
+```bash
+BAC009S0764W0121 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav
+BAC009S0916W0489 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0916W0489.wav
+ID0012W0015 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cn_en.wav
+```
+
+`Command`
+
+```shell
+# generate train.jsonl and val.jsonl from wav.scp and text.txt
+scp2jsonl \
+++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out="../../../data/list/train.jsonl"
+```
+
+(Optional) If you need to convert a jsonl back into wav.scp and text.txt, use the following command:
+
+```shell
+# generate wav.scp and text.txt from train.jsonl and val.jsonl
+jsonl2scp \
+++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_in="../../../data/list/train.jsonl"
+```
+
+#### Training log
+
+##### log.txt
+```shell
+tail log.txt
+[2024-03-21 15:55:52,137][root][INFO] - train, rank: 3, epoch: 0/50, step: 6990/1, total step: 6990, (loss_avg_rank: 0.327), (loss_avg_epoch: 0.409), (ppl_avg_epoch: 1.506), (acc_avg_epoch: 0.795), (lr: 1.165e-04), [('loss_att', 0.259), ('acc', 0.825), ('loss_pre', 0.04), ('loss', 0.299), ('batch_size', 40)], {'data_load': '0.000', 'forward_time': '0.315', 'backward_time': '0.555', 'optim_time': '0.076', 'total_time': '0.947'}, GPU, memory: usage: 3.830 GB, peak: 18.357 GB, cache: 20.910 GB, cache_peak: 20.910 GB
+[2024-03-21 15:55:52,139][root][INFO] - train, rank: 1, epoch: 0/50, step: 6990/1, total step: 6990, (loss_avg_rank: 0.334), (loss_avg_epoch: 0.409), (ppl_avg_epoch: 1.506), (acc_avg_epoch: 0.795), (lr: 1.165e-04), [('loss_att', 0.285), ('acc', 0.823), ('loss_pre', 0.046), ('loss', 0.331), ('batch_size', 36)], {'data_load': '0.000', 'forward_time': '0.334', 'backward_time': '0.536', 'optim_time': '0.077', 'total_time': '0.948'}, GPU, memory: usage: 3.943 GB, peak: 18.291 GB, cache: 19.619 GB, cache_peak: 19.619 GB
+```
+
+
+- `rank`: GPU id.
+- `epoch`,`step`,`total step`: the current epoch, step, and total step count.
+- `loss_avg_rank`: the loss of the current step, averaged across all GPUs.
+- `loss/ppl/acc_avg_epoch`: the running average loss/perplexity/accuracy over the current epoch, up to the current step; at the last step of an epoch it is the epoch's overall average. The accuracy metric is recommended.
+- `lr`: the learning rate at the current step.
+- `[('loss_att', 0.259), ('acc', 0.825), ('loss_pre', 0.04), ('loss', 0.299), ('batch_size', 40)]`: the metrics of the current GPU.
+- `total_time`: the total wall-clock time of a single step.
+- `GPU, memory`: the memory used/peaked by the model, and the memory used/peaked by the model plus cache.
+
+##### tensorboard
+```bash
+tensorboard --logdir /xxxx/FunASR/examples/industrial_data_pretraining/paraformer/outputs/log/tensorboard
+```
+http://localhost:6006/
+
+### Model Testing After Training
+
+
+#### With `configuration.json` file
+
+Assuming the trained model path is `./model_dir`: if a `configuration.json` has been generated in that directory, you only need to replace the model name with the model path in the inference methods above.
+
+For example, for shell inference:
+```shell
+python -m funasr.bin.inference ++model="./model_dir" ++input="${input}" ++output_dir="${output_dir}"
+```
+
+Python inference
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="./model_dir")
+
+wav_file = "asr_example.wav"  # path to your test audio (placeholder)
+res = model.generate(input=wav_file)
+print(res)
+```
+
+#### Without `configuration.json` file
+
+If there is no configuration.json in the model path, you need to manually specify the exact configuration file path and the model path.
+
+```shell
+python -m funasr.bin.inference \
+--config-path "${local_path}" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++tokenizer_conf.token_list="${tokens}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}"
+```
+
+Parameter Introduction
+- `config-path`: the directory containing the config.yaml saved during the experiment, found in the experiment's output directory.
+- `config-name`: the configuration file name, usually `config.yaml`. Both YAML and JSON formats are supported, e.g. `config.json`.
+- `init_param`: the model parameters to test, usually `model.pt`; choose whichever checkpoint file you need.
+- `tokenizer_conf.token_list`: the vocabulary file path, normally already specified in config.yaml; you only need to set it here if the path recorded in config.yaml is incorrect.
+- `frontend_conf.cmvn_file`: the CMVN (cepstral mean and variance normalization) file used when extracting fbank features from WAV files, normally already specified in config.yaml; you only need to set it here if the path recorded in config.yaml is incorrect.
+
+Other parameters are the same as mentioned above. A complete [example](https://github.com/alibaba-damo-academy/FunASR/blob/main/examples/industrial_data_pretraining/paraformer/infer_from_local.sh) can be found here.
+
+<a name="Export"></a>
+## Export ONNX
+
+### Command-line usage
+```shell
+funasr-export ++model=paraformer ++quantize=false ++device=cpu
+```
+
+### Python
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer", device="cpu")
+
+res = model.export(quantize=False)
+```
+
+### Test ONNX
+```python
+# pip3 install -U funasr-onnx
+from funasr_onnx import Paraformer
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=True)
+
+wav_path = ['~/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+
+result = model(wav_path)
+print(result)
+```
+
+For more examples, see the [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/runtime/python/onnxruntime)
\ No newline at end of file
diff --git a/examples/README_zh.md b/examples/README_zh.md
index fa85290..4e9bb3f 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -235,13 +235,17 @@
- `valid_data_set_list` (str): Path to the validation data, jsonl format by default; see the ([examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
-- `train_conf.max_epoch` (int): Total number of training epochs.
-- `train_conf.log_interval` (int): Number of steps between log prints.
-- `train_conf.resume` (int): Whether to enable resuming from a checkpoint.
-- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
-- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
-- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
-- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `train_conf.max_epoch` (int): `100` (default), total number of training epochs.
+- `train_conf.log_interval` (int): `50` (default), number of steps between log prints.
+- `train_conf.resume` (bool): `True` (default), whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): `5000` (default), interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): `5000` (default), interval in steps for saving checkpoints during training.
+- `train_conf.avg_keep_nbest_models_type` (str): `acc` (default), keep the n-best checkpoints by validation accuracy (higher is better); `loss` keeps the n-best by validation loss (lower is better).
+- `train_conf.keep_nbest_models` (int): `500` (default), maximum number of checkpoints to keep; together with `avg_keep_nbest_models_type`, the best n checkpoints by validation acc/loss are kept and the rest are deleted to save storage space.
+- `train_conf.avg_nbest_model` (int): `5` (default), average the best n checkpoints, ranked by validation acc/loss according to `avg_keep_nbest_models_type`.
+- `train_conf.accum_grad` (int): `1` (default), gradient accumulation steps.
+- `train_conf.grad_clip` (float): `10.0` (default), gradient clipping threshold.
+- `train_conf.use_fp16` (bool): `False` (default), enable fp16 training to speed up training.
- `optim_conf.lr` (float): The learning rate.
- `output_dir` (str): Path for saving the model.
- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/README_zh.md b/examples/industrial_data_pretraining/paraformer-zh-spk/README_zh.md
index fa85290..4e9bb3f 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/README_zh.md
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/README_zh.md
@@ -235,13 +235,17 @@
- `valid_data_set_list` (str): Path to the validation data, jsonl format by default; see the ([examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
-- `train_conf.max_epoch` (int): Total number of training epochs.
-- `train_conf.log_interval` (int): Number of steps between log prints.
-- `train_conf.resume` (int): Whether to enable resuming from a checkpoint.
-- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
-- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
-- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
-- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `train_conf.max_epoch` (int): `100` (default), total number of training epochs.
+- `train_conf.log_interval` (int): `50` (default), number of steps between log prints.
+- `train_conf.resume` (bool): `True` (default), whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): `5000` (default), interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): `5000` (default), interval in steps for saving checkpoints during training.
+- `train_conf.avg_keep_nbest_models_type` (str): `acc` (default), keep the n-best checkpoints by validation accuracy (higher is better); `loss` keeps the n-best by validation loss (lower is better).
+- `train_conf.keep_nbest_models` (int): `500` (default), maximum number of checkpoints to keep; together with `avg_keep_nbest_models_type`, the best n checkpoints by validation acc/loss are kept and the rest are deleted to save storage space.
+- `train_conf.avg_nbest_model` (int): `5` (default), average the best n checkpoints, ranked by validation acc/loss according to `avg_keep_nbest_models_type`.
+- `train_conf.accum_grad` (int): `1` (default), gradient accumulation steps.
+- `train_conf.grad_clip` (float): `10.0` (default), gradient clipping threshold.
+- `train_conf.use_fp16` (bool): `False` (default), enable fp16 training to speed up training.
- `optim_conf.lr` (float): The learning rate.
- `output_dir` (str): Path for saving the model.
- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
diff --git a/examples/industrial_data_pretraining/paraformer/README.md b/examples/industrial_data_pretraining/paraformer/README.md
new file mode 100644
index 0000000..20102cc
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/README.md
@@ -0,0 +1,424 @@
+([简体中文](./README_zh.md)|English)
+
+FunASR has open-sourced a large number of pre-trained models on industrial data. You are free to use, copy, modify, and share FunASR models under the [Model License Agreement](https://github.com/alibaba-damo-academy/FunASR/blob/main/MODEL_LICENSE). Below, we list some representative models. For a comprehensive list, please refer to our [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo).
+
+<div align="center">
+<h4>
+ <a href="#Inference"> Model Inference </a>
+｜<a href="#Training"> Model Training and Testing </a>
+｜<a href="#Export"> Model Export and Testing </a>
+</h4>
+</div>
+
+<a name="Inference"></a>
+## Model Inference
+
+### Quick Start
+
+For command-line invocation:
+```shell
+funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++input=asr_example_zh.wav
+```
+
+For python code invocation (recommended):
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh")
+
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
+print(res)
+```
+
+### API Description
+#### AutoModel Definition
+```python
+model = AutoModel(model=[str], device=[str], ncpu=[int], output_dir=[str], batch_size=[int], hub=[str], **kwargs)
+```
+- `model`(str): model name in the [Model Repository](https://github.com/alibaba-damo-academy/FunASR/tree/main/model_zoo), or a model path on local disk.
+- `device`(str): `cuda:0` (default gpu0) for using GPU for inference, specify `cpu` for using CPU.
+- `ncpu`(int): `4` (default), sets the number of threads for CPU internal operations.
+- `output_dir`(str): `None` (default), set this to specify the output path for the results.
+- `batch_size`(int): `1` (default), the number of samples per batch during decoding.
+- `hub`(str): `ms` (default), download models from ModelScope; use `hf` to download models from Hugging Face.
+- `**kwargs`(dict): Any parameter in config.yaml can be specified directly here, for instance the maximum segmentation length of the VAD model, `max_single_segment_time=6000` (milliseconds); see the sketch below.
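+
+As a sketch of this (the value is illustrative; `max_single_segment_time` is simply a `config.yaml` key of the fsmn-vad model, as noted above):
+
+```python
+from funasr import AutoModel
+
+# max_single_segment_time lives in the vad model's config.yaml;
+# any other config.yaml key can be overridden the same way.
+model = AutoModel(model="fsmn-vad", max_single_segment_time=6000)
+```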
+
+#### AutoModel Inference
+```python
+res = model.generate(input=[str], output_dir=[str])
+```
+- `input`: The input to be decoded, which could be:
+ - A wav file path, e.g., asr_example.wav
+ - A pcm file path, e.g., asr_example.pcm, in this case, specify the audio sampling rate fs (default is 16000)
+ - An audio byte stream, e.g., byte data from a microphone
+ - A wav.scp, a Kaldi-style wav list (wav_id \t wav_path), for example:
+ ```text
+ asr_example1 ./audios/asr_example1.wav
+ asr_example2 ./audios/asr_example2.wav
+ ```
+ When using wav.scp as input, you must set output_dir to save the output results.
+ - Audio samples, e.g. `audio, rate = soundfile.read("asr_example_zh.wav")`, with data type numpy.ndarray. Batch input is supported as a list (see the sketch after this list):
+ ```[audio_sample1, audio_sample2, ..., audio_sampleN]```
+ - fbank input, supports batch grouping. Shape is [batch, frames, dim], type is torch.Tensor.
+- `output_dir`: None (default), if set, specifies the output path for the results.
+- `**kwargs`(dict): Inference parameters related to the model, for example, `beam_size=10`, `decoding_ctc_weight=0.1`.
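+
+For instance, a minimal sketch of batch input with numpy.ndarray samples (the two wav paths are placeholders for your own files):
+
+```python
+import soundfile
+
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh", batch_size=2)
+
+# each element is a numpy.ndarray of audio samples; the list forms one batch
+audio1, rate = soundfile.read("asr_example1.wav")
+audio2, rate = soundfile.read("asr_example2.wav")
+res = model.generate(input=[audio1, audio2])
+print(res)
+```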
+
+
+### More Usage Examples
+
+
+#### Speech Recognition (Non-streaming)
+```python
+from funasr import AutoModel
+# paraformer-zh is a multi-functional asr model
+# use vad, punc, spk or not as you need
+model = AutoModel(model="paraformer-zh",
+ vad_model="fsmn-vad",
+ vad_kwargs={"max_single_segment_time": 60000},
+ punc_model="ct-punc",
+ # spk_model="cam++"
+ )
+wav_file = f"{model.model_path}/example/asr_example.wav"
+res = model.generate(input=wav_file, batch_size_s=300, batch_size_threshold_s=60, hotword='魔搭')
+print(res)
+```
+Notes:
+- Typically, model input duration is limited to under 30 seconds. Combined with `vad_model`, however, audio input of any length is supported; this is not limited to paraformer, and any audio model can be used.
+- Parameters related to model can be directly specified in the definition of AutoModel; parameters related to `vad_model` can be set through `vad_kwargs`, which is a dict; similar parameters include `punc_kwargs` and `spk_kwargs`.
+- `max_single_segment_time`: Denotes the maximum audio segmentation length for `vad_model`, measured in milliseconds (ms).
+- `batch_size_s` represents the use of dynamic batching, where the total audio duration within a batch is measured in seconds (s).
+- `batch_size_threshold_s`: Indicates that when the duration of an audio segment post-VAD segmentation exceeds the batch_size_threshold_s threshold, the batch size is set to 1, measured in seconds (s).
+
+Recommendations:
+
+When you input long audio and encounter Out Of Memory (OOM) issues, since memory usage tends to increase quadratically with audio length, consider the following three scenarios:
+
+a) At the beginning of inference, memory usage depends mainly on `batch_size_s`; reducing it lowers memory usage.
+b) In the middle of inference, if long VAD-cut segments trigger OOM even though the total duration per batch is below `batch_size_s`, reduce `batch_size_threshold_s`; any segment exceeding this threshold is decoded with batch size 1.
+c) Towards the end of inference, if a long VAD-cut segment already exceeds `batch_size_threshold_s` (so it is decoded with batch size 1) and OOM persists, reduce `max_single_segment_time` to shorten the VAD segment length.
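+
+Putting the three knobs together, a hedged starting point for long-audio inference under tight GPU memory (the values and the input path are illustrative, not tuned recommendations):
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer-zh",
+                  vad_model="fsmn-vad",
+                  # (c) shorter VAD segments bound per-segment memory
+                  vad_kwargs={"max_single_segment_time": 30000})
+res = model.generate(input="long_audio.wav",
+                     batch_size_s=150,           # (a) less total audio per batch
+                     batch_size_threshold_s=30)  # (b) force batch size 1 sooner
+```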
+
+#### Speech Recognition (Streaming)
+```python
+from funasr import AutoModel
+
+chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
+decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
+
+model = AutoModel(model="paraformer-zh-streaming")
+
+import soundfile
+import os
+
+wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+speech, sample_rate = soundfile.read(wav_file)
+chunk_stride = chunk_size[1] * 960 # 600ms
+
+cache = {}
+total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+for i in range(total_chunk_num):
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+ is_final = i == total_chunk_num - 1
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+ print(res)
+```
+Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` means the real-time output granularity is `10*60=600ms`, with `5*60=300ms` of lookahead. Each inference call consumes `600ms` of input (`16000*0.6=9600` sample points) and outputs the corresponding text. For the last speech segment, set `is_final=True` to force output of the last word.
+
+#### Voice Activity Detection (Non-Streaming)
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="fsmn-vad")
+wav_file = f"{model.model_path}/example/asr_example.wav"
+res = model.generate(input=wav_file)
+print(res)
+```
+Note: The output format of the VAD model is: `[[beg1, end1], [beg2, end2], ..., [begN, endN]]`, where `begN/endN` indicates the starting/ending point of the `N-th` valid audio segment, measured in milliseconds.
+
+#### Voice Activity Detection (Streaming)
+```python
+from funasr import AutoModel
+
+chunk_size = 200 # ms
+model = AutoModel(model="fsmn-vad")
+
+import soundfile
+
+wav_file = f"{model.model_path}/example/vad_example.wav"
+speech, sample_rate = soundfile.read(wav_file)
+chunk_stride = int(chunk_size * sample_rate / 1000)
+
+cache = {}
+total_chunk_num = int((len(speech) - 1) / chunk_stride + 1)
+for i in range(total_chunk_num):
+ speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+ is_final = i == total_chunk_num - 1
+ res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+ if len(res[0]["value"]):
+ print(res)
+```
+Note: The output format for the streaming VAD model can be one of four scenarios:
+- `[[beg1, end1], [beg2, end2], .., [begN, endN]]`: the same as the offline VAD output above.
+- `[[beg, -1]]`: only a starting point has been detected.
+- `[[-1, end]]`: only an ending point has been detected.
+- `[]`: neither a starting point nor an ending point has been detected.
+
+The output is measured in milliseconds and represents the absolute time from the starting point.
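+
+For example, a small helper sketch (hypothetical, not part of the FunASR API) that stitches these chunk-level outputs back into complete `[beg, end]` segments:
+
+```python
+def stitch_vad_segments(chunk_values):
+    """chunk_values: the res[0]["value"] lists of all chunks, concatenated."""
+    segments, pending_beg = [], None
+    for beg, end in chunk_values:
+        if beg != -1 and end != -1:    # complete segment within one chunk
+            segments.append([beg, end])
+        elif beg != -1:                # only a starting point so far
+            pending_beg = beg
+        elif pending_beg is not None:  # only an ending point; close the segment
+            segments.append([pending_beg, end])
+            pending_beg = None
+    return segments
+```
+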
+#### Punctuation Restoration
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="ct-punc")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
+print(res)
+```
+#### Timestamp Prediction
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="fa-zh")
+wav_file = f"{model.model_path}/example/asr_example.wav"
+text_file = f"{model.model_path}/example/text.txt"
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
+print(res)
+```
+
+For more examples, see the [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
+
+<a name="Training"></a>
+## Model Training and Testing
+
+### Quick Start
+
+Execute via command line (for quick testing, not recommended):
+```shell
+funasr-train ++model=paraformer-zh ++train_data_set_list=data/list/train.jsonl ++valid_data_set_list=data/list/val.jsonl ++output_dir="./outputs" &> log.txt &
+```
+
+Execute with Python code (supports multi-node and multi-GPU, recommended):
+
+```shell
+cd examples/industrial_data_pretraining/paraformer
+bash finetune.sh
+# "log_file: ./outputs/log.txt"
+```
+For the full code, see [finetune.sh](https://github.com/alibaba-damo-academy/FunASR/blob/main/examples/industrial_data_pretraining/paraformer/finetune.sh)
+
+### Detailed Parameter Description:
+
+```shell
+funasr/bin/train.py \
+++model="${model_name_or_model_dir}" \
+++train_data_set_list="${train_data}" \
+++valid_data_set_list="${val_data}" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=1 \
+++train_conf.resume=false \
+++train_conf.validate_interval=2000 \
+++train_conf.save_checkpoint_interval=2000 \
+++train_conf.keep_nbest_models=20 \
+++train_conf.avg_nbest_model=5 \
+++optim_conf.lr=0.0002 \
+++output_dir="${output_dir}" &> ${log_file}
+```
+
+- `model` (str): The model name (its ID in the model repository), in which case the model is downloaded automatically to local storage; alternatively, the path to a model already downloaded locally.
+- `train_data_set_list` (str): Path to the training data, typically in jsonl format; see the [examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list) for details.
+- `valid_data_set_list` (str): Path to the validation data, also typically in jsonl format; see the [examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list) for details.
+- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
+- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
+- `train_conf.max_epoch` (int): Total number of training epochs.
+- `train_conf.log_interval` (int): Number of steps between log prints.
+- `train_conf.resume` (bool): Whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
+- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
+- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `optim_conf.lr` (float): The learning rate.
+- `output_dir` (str): Path for saving the model.
+- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
+
+#### Multi-GPU Training
+##### Single-Machine Multi-GPU Training
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+../../../funasr/bin/train.py ${train_args}
+```
+`--nnodes` specifies the total number of participating nodes, while `--nproc_per_node` specifies the number of processes running on each node.
+
+##### Multi-Machine Multi-GPU Training
+
+On the master node, assuming the IP is 192.168.1.1 and the port is 12345, and you're using 2 GPUs, you would run the following command:
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 2 --node_rank 0 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+../../../funasr/bin/train.py ${train_args}
+```
+On the worker node (assuming the IP is 192.168.1.2), make sure the master address and port match those of the master node, then run the same command with `--node_rank` set to 1:
+
+```shell
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+torchrun --nnodes 2 --node_rank 1 --nproc_per_node ${gpu_num} --master_addr=192.168.1.1 --master_port=12345 \
+../../../funasr/bin/train.py ${train_args}
+```
+
+`--nnodes` indicates the total number of nodes participating in training, `--node_rank` is the rank of the current node, and `--nproc_per_node` specifies the number of processes per node (usually the number of GPUs).
+
+#### Data Preparation
+
+For the `jsonl` format, see the ([demo](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
+The `scp2jsonl` command can generate it from a `wav.scp` and a `text.txt`, prepared as follows:
+
+`train_text.txt`
+
+```bash
+ID0012W0013 当客户风险承受能力评估依据发生变化时
+ID0012W0014 所有只要处理 data 不管你是做 machine learning 做 deep learning
+ID0012W0015 he tried to think how it could be
+```
+
+
+`train_wav.scp`
+
+
+```bash
+BAC009S0764W0121 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav
+BAC009S0916W0489 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0916W0489.wav
+ID0012W0015 https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_cn_en.wav
+```
+
+`Command`
+
+```shell
+# generate train.jsonl and val.jsonl from wav.scp and text.txt
+scp2jsonl \
+++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out="../../../data/list/train.jsonl"
+```
+
+(Optional) If you need to convert a jsonl back into wav.scp and text.txt, use the following command:
+
+```shell
+# generate wav.scp and text.txt from train.jsonl and val.jsonl
+jsonl2scp \
+++scp_file_list='["../../../data/list/train_wav.scp", "../../../data/list/train_text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_in="../../../data/list/train.jsonl"
+```
+
+#### Training log
+
+##### log.txt
+```shell
+tail log.txt
+[2024-03-21 15:55:52,137][root][INFO] - train, rank: 3, epoch: 0/50, step: 6990/1, total step: 6990, (loss_avg_rank: 0.327), (loss_avg_epoch: 0.409), (ppl_avg_epoch: 1.506), (acc_avg_epoch: 0.795), (lr: 1.165e-04), [('loss_att', 0.259), ('acc', 0.825), ('loss_pre', 0.04), ('loss', 0.299), ('batch_size', 40)], {'data_load': '0.000', 'forward_time': '0.315', 'backward_time': '0.555', 'optim_time': '0.076', 'total_time': '0.947'}, GPU, memory: usage: 3.830 GB, peak: 18.357 GB, cache: 20.910 GB, cache_peak: 20.910 GB
+[2024-03-21 15:55:52,139][root][INFO] - train, rank: 1, epoch: 0/50, step: 6990/1, total step: 6990, (loss_avg_rank: 0.334), (loss_avg_epoch: 0.409), (ppl_avg_epoch: 1.506), (acc_avg_epoch: 0.795), (lr: 1.165e-04), [('loss_att', 0.285), ('acc', 0.823), ('loss_pre', 0.046), ('loss', 0.331), ('batch_size', 36)], {'data_load': '0.000', 'forward_time': '0.334', 'backward_time': '0.536', 'optim_time': '0.077', 'total_time': '0.948'}, GPU, memory: usage: 3.943 GB, peak: 18.291 GB, cache: 19.619 GB, cache_peak: 19.619 GB
+```
+
+
+- `rank`: GPU id.
+- `epoch`,`step`,`total step`: the current epoch, step, and total step count.
+- `loss_avg_rank`: the loss of the current step, averaged across all GPUs.
+- `loss/ppl/acc_avg_epoch`: the running average loss/perplexity/accuracy over the current epoch, up to the current step; at the last step of an epoch it is the epoch's overall average. The accuracy metric is recommended.
+- `lr`: the learning rate at the current step.
+- `[('loss_att', 0.259), ('acc', 0.825), ('loss_pre', 0.04), ('loss', 0.299), ('batch_size', 40)]`: the metrics of the current GPU.
+- `total_time`: the total wall-clock time of a single step.
+- `GPU, memory`: the memory used/peaked by the model, and the memory used/peaked by the model plus cache.
+
+##### tensorboard
+```bash
+tensorboard --logdir /xxxx/FunASR/examples/industrial_data_pretraining/paraformer/outputs/log/tensorboard
+```
+http://localhost:6006/
+
+### Model Testing After Training
+
+
+#### With `configuration.json` file
+
+Assuming the trained model path is `./model_dir`: if a `configuration.json` has been generated in that directory, you only need to replace the model name with the model path in the inference methods above.
+
+For example, for shell inference:
+```shell
+python -m funasr.bin.inference ++model="./model_dir" ++input="${input}" ++output_dir="${output_dir}"
+```
+
+Python inference
+
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="./model_dir")
+
+wav_file = "asr_example.wav"  # path to your test audio (placeholder)
+res = model.generate(input=wav_file)
+print(res)
+```
+
+#### Without `configuration.json` file
+
+If there is no configuration.json in the model path, you need to manually specify the exact configuration file path and the model path.
+
+```shell
+python -m funasr.bin.inference \
+--config-path "${local_path}" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++tokenizer_conf.token_list="${tokens}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}"
+```
+
+Parameter Introduction
+- `config-path`: the directory containing the config.yaml saved during the experiment, found in the experiment's output directory.
+- `config-name`: the configuration file name, usually `config.yaml`. Both YAML and JSON formats are supported, e.g. `config.json`.
+- `init_param`: the model parameters to test, usually `model.pt`; choose whichever checkpoint file you need.
+- `tokenizer_conf.token_list`: the vocabulary file path, normally already specified in config.yaml; you only need to set it here if the path recorded in config.yaml is incorrect.
+- `frontend_conf.cmvn_file`: the CMVN (cepstral mean and variance normalization) file used when extracting fbank features from WAV files, normally already specified in config.yaml; you only need to set it here if the path recorded in config.yaml is incorrect.
+
+Other parameters are the same as mentioned above. A complete [example](https://github.com/alibaba-damo-academy/FunASR/blob/main/examples/industrial_data_pretraining/paraformer/infer_from_local.sh) can be found here.
+
+<a name="Export"></a>
+## Export ONNX
+
+### Command-line usage
+```shell
+funasr-export ++model=paraformer ++quantize=false ++device=cpu
+```
+
+### Python
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer", device="cpu")
+
+res = model.export(quantize=False)
+```
+
+### Test ONNX
+```python
+# pip3 install -U funasr-onnx
+from funasr_onnx import Paraformer
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=True)
+
+wav_path = ['~/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+
+result = model(wav_path)
+print(result)
+```
+
+For more examples, see the [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/runtime/python/onnxruntime)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/README_zh.md b/examples/industrial_data_pretraining/paraformer/README_zh.md
index fa85290..4e9bb3f 100644
--- a/examples/industrial_data_pretraining/paraformer/README_zh.md
+++ b/examples/industrial_data_pretraining/paraformer/README_zh.md
@@ -235,13 +235,17 @@
- `valid_data_set_list` (str): Path to the validation data, jsonl format by default; see the ([examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
-- `train_conf.max_epoch` (int): Total number of training epochs.
-- `train_conf.log_interval` (int): Number of steps between log prints.
-- `train_conf.resume` (int): Whether to enable resuming from a checkpoint.
-- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
-- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
-- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
-- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `train_conf.max_epoch` (int): `100` (default), total number of training epochs.
+- `train_conf.log_interval` (int): `50` (default), number of steps between log prints.
+- `train_conf.resume` (bool): `True` (default), whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): `5000` (default), interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): `5000` (default), interval in steps for saving checkpoints during training.
+- `train_conf.avg_keep_nbest_models_type` (str): `acc` (default), keep the n-best checkpoints by validation accuracy (higher is better); `loss` keeps the n-best by validation loss (lower is better).
+- `train_conf.keep_nbest_models` (int): `500` (default), maximum number of checkpoints to keep; together with `avg_keep_nbest_models_type`, the best n checkpoints by validation acc/loss are kept and the rest are deleted to save storage space.
+- `train_conf.avg_nbest_model` (int): `5` (default), average the best n checkpoints, ranked by validation acc/loss according to `avg_keep_nbest_models_type`.
+- `train_conf.accum_grad` (int): `1` (default), gradient accumulation steps.
+- `train_conf.grad_clip` (float): `10.0` (default), gradient clipping threshold.
+- `train_conf.use_fp16` (bool): `False` (default), enable fp16 training to speed up training.
- `optim_conf.lr` (float): The learning rate.
- `output_dir` (str): Path for saving the model.
- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/README_zh.md b/examples/industrial_data_pretraining/paraformer_streaming/README_zh.md
index fa85290..4e9bb3f 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/README_zh.md
+++ b/examples/industrial_data_pretraining/paraformer_streaming/README_zh.md
@@ -235,13 +235,17 @@
- `valid_data_set_list` (str): Path to the validation data, jsonl format by default; see the ([examples](https://github.com/alibaba-damo-academy/FunASR/blob/main/data/list)).
- `dataset_conf.batch_type` (str): `example` (default), the batch type. `example` forms batches with a fixed number (`batch_size`) of samples; `length` or `token` enables dynamic batching, where the total length or token count of a batch equals `batch_size`.
- `dataset_conf.batch_size` (int): Used together with `batch_type`. When `batch_type=example`, it is the number of samples; when `batch_type=length`, it is the total sample length, measured in fbank frames (1 frame = 10 ms) or text tokens.
-- `train_conf.max_epoch` (int): Total number of training epochs.
-- `train_conf.log_interval` (int): Number of steps between log prints.
-- `train_conf.resume` (int): Whether to enable resuming from a checkpoint.
-- `train_conf.validate_interval` (int): Interval in steps for running validation during training.
-- `train_conf.save_checkpoint_interval` (int): Interval in steps for saving checkpoints during training.
-- `train_conf.keep_nbest_models` (int): Maximum number of checkpoints to keep, sorted by validation accuracy from high to low.
-- `train_conf.avg_nbest_model` (int): Average the n models with the highest validation accuracy.
+- `train_conf.max_epoch` (int): `100` (default), total number of training epochs.
+- `train_conf.log_interval` (int): `50` (default), number of steps between log prints.
+- `train_conf.resume` (bool): `True` (default), whether to enable resuming from a checkpoint.
+- `train_conf.validate_interval` (int): `5000` (default), interval in steps for running validation during training.
+- `train_conf.save_checkpoint_interval` (int): `5000` (default), interval in steps for saving checkpoints during training.
+- `train_conf.avg_keep_nbest_models_type` (str): `acc` (default), keep the n-best checkpoints by validation accuracy (higher is better); `loss` keeps the n-best by validation loss (lower is better).
+- `train_conf.keep_nbest_models` (int): `500` (default), maximum number of checkpoints to keep; together with `avg_keep_nbest_models_type`, the best n checkpoints by validation acc/loss are kept and the rest are deleted to save storage space.
+- `train_conf.avg_nbest_model` (int): `5` (default), average the best n checkpoints, ranked by validation acc/loss according to `avg_keep_nbest_models_type`.
+- `train_conf.accum_grad` (int): `1` (default), gradient accumulation steps.
+- `train_conf.grad_clip` (float): `10.0` (default), gradient clipping threshold.
+- `train_conf.use_fp16` (bool): `False` (default), enable fp16 training to speed up training.
- `optim_conf.lr` (float): The learning rate.
- `output_dir` (str): Path for saving the model.
- `**kwargs` (dict): Any parameter in `config.yaml` can be specified directly here; for example, to filter out audio longer than 20 s: `dataset_conf.max_token_length=2000`, measured in fbank frames (1 frame = 10 ms) or text tokens.
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index d19b79a..880bb63 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -214,7 +214,7 @@
if trainer.rank == 0:
- average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+ average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
trainer.close()
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 3603a44..013a719 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -16,139 +16,38 @@
from functools import cmp_to_key
-# @torch.no_grad()
-# def average_nbest_models(
-# output_dir: Path,
-# best_model_criterion: Sequence[Sequence[str]],
-# nbest: Union[Collection[int], int],
-# suffix: Optional[str] = None,
-# oss_bucket=None,
-# pai_output_dir=None,
-# ) -> None:
-# """Generate averaged model from n-best models
-#
-# Args:
-# output_dir: The directory contains the model file for each epoch
-# reporter: Reporter instance
-# best_model_criterion: Give criterions to decide the best model.
-# e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-# nbest: Number of best model files to be averaged
-# suffix: A suffix added to the averaged model file name
-# """
-# if isinstance(nbest, int):
-# nbests = [nbest]
-# else:
-# nbests = list(nbest)
-# if len(nbests) == 0:
-# warnings.warn("At least 1 nbest values are required")
-# nbests = [1]
-# if suffix is not None:
-# suffix = suffix + "."
-# else:
-# suffix = ""
-#
-# # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-# nbest_epochs = [
-# (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-# for ph, k, m in best_model_criterion
-# if reporter.has(ph, k)
-# ]
-#
-# _loaded = {}
-# for ph, cr, epoch_and_values in nbest_epochs:
-# _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-# if len(_nbests) == 0:
-# _nbests = [1]
-#
-# for n in _nbests:
-# if n == 0:
-# continue
-# elif n == 1:
-# # The averaged model is same as the best model
-# e, _ = epoch_and_values[0]
-# op = output_dir / f"{e}epoch.pb"
-# sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-# if sym_op.is_symlink() or sym_op.exists():
-# sym_op.unlink()
-# sym_op.symlink_to(op.name)
-# else:
-# op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-# logging.info(
-# f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-# )
-#
-# avg = None
-# # 2.a. Averaging model
-# for e, _ in epoch_and_values[:n]:
-# if e not in _loaded:
-# if oss_bucket is None:
-# _loaded[e] = torch.load(
-# output_dir / f"{e}epoch.pb",
-# map_location="cpu",
-# )
-# else:
-# buffer = BytesIO(
-# oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-# _loaded[e] = torch.load(buffer)
-# states = _loaded[e]
-#
-# if avg is None:
-# avg = states
-# else:
-# # Accumulated
-# for k in avg:
-# avg[k] = avg[k] + states[k]
-# for k in avg:
-# if str(avg[k].dtype).startswith("torch.int"):
-# # For int type, not averaged, but only accumulated.
-# # e.g. BatchNorm.num_batches_tracked
-# # (If there are any cases that requires averaging
-# # or the other reducing method, e.g. max/min, for integer type,
-# # please report.)
-# pass
-# else:
-# avg[k] = avg[k] / n
-#
-# # 2.b. Save the ave model and create a symlink
-# if oss_bucket is None:
-# torch.save(avg, op)
-# else:
-# buffer = BytesIO()
-# torch.save(avg, buffer)
-# oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-# buffer.getvalue())
-#
-# # 3. *.*.ave.pb is a symlink to the max ave model
-# if oss_bucket is None:
-# op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-# sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-# if sym_op.is_symlink() or sym_op.exists():
-# sym_op.unlink()
-# sym_op.symlink_to(op.name)
-
def _get_checkpoint_paths(output_dir: str, last_n: int=5):
"""
Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory.
"""
- # List all files in the output directory
- files = os.listdir(output_dir)
- # Filter out checkpoint files and extract epoch numbers
- checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
- # Sort files by epoch number in descending order
- checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
- # Get the last 'last_n' checkpoint paths
- checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+    try:
+        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
+        # ckpt_name -> validation acc/loss for the checkpoints still on disk
+        saved_ckpts = checkpoint["saved_ckpts"]
+        sorted_items = sorted(saved_ckpts.items(), key=lambda x: x[1], reverse=True)
+        # acc: larger is better, keep the head; loss: smaller is better, keep the tail
+        sorted_items = sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
+        checkpoint_paths = [os.path.join(output_dir, key) for key, value in sorted_items]
+    except Exception:
+ # List all files in the output directory
+ files = os.listdir(output_dir)
+ # Filter out checkpoint files and extract epoch numbers
+ checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+ # Sort files by epoch number in descending order
+ checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
+ # Get the last 'last_n' checkpoint paths
+ checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
return checkpoint_paths
@torch.no_grad()
-def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
+def average_checkpoints(output_dir: str, last_n: int=5, **kwargs):
"""
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
"""
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+ print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = []
# Load state_dicts from checkpoints
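The hunk is truncated before the averaging loop, but the deleted comment block above documents the intended arithmetic: sum the state_dicts, divide float tensors by n, and leave integer tensors (e.g. BatchNorm's `num_batches_tracked`) accumulated. A self-contained sketch of that arithmetic, assuming each file stores its weights under the `state_dict` key as `save_checkpoint` in `trainer.py` below does (hypothetical helper, not the patched function):

```python
import torch

@torch.no_grad()
def average_state_dicts(paths):
    """Average the model weights of several checkpoint files on CPU."""
    avg = None
    for p in paths:
        states = torch.load(p, map_location="cpu")
        states = states.get("state_dict", states)  # unwrap trainer checkpoints
        if avg is None:
            avg = {k: v.clone() for k, v in states.items()}
        else:
            for k in avg:
                avg[k] += states[k]
    for k in avg:
        # integer tensors are summed, not averaged
        if not str(avg[k].dtype).startswith("torch.int"):
            avg[k] = avg[k] / len(paths)
    return avg
```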
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 56ec604..116c9e3 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -71,21 +71,19 @@
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
self.device = kwargs.get('device', "cuda")
- self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
# self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.use_fp16 = use_fp16
- self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
- # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
- # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
- # self.scaler = scaler
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
- self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
+ self.validate_interval = kwargs.get("validate_interval", 5000)
+ self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
+ self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
+ self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
- self.validate_interval = kwargs.get("validate_interval", 5000)
+
try:
@@ -103,8 +101,10 @@
self.val_loss_avg = 0.0
self.best_acc_idx = 0
self.saved_ckpts = {}
- self.val_acc_list = []
self.step_or_epoch = -1
+ self.best_step_or_epoch = ""
+        self.val_acc_step_or_epoch = {}
+        self.val_loss_step_or_epoch = {}
def save_checkpoint(self, epoch,
step=None,
@@ -124,14 +124,17 @@
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
- self.step_or_epoch += 1
+ # self.step_or_epoch += 1
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optim.state_dict(),
'scheduler': scheduler.state_dict(),
- "acc": self.val_acc_list,
- "step_or_epoch": self.step_or_epoch,
+ "saved_ckpts": self.saved_ckpts,
+ "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
+ "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "best_step_or_epoch": self.best_step_or_epoch,
+ "avg_keep_nbest_models_type": slef.avg_keep_nbest_models_type,
}
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -150,23 +153,37 @@
logging.info(f'\nCheckpoint saved to {filename}\n')
latest = Path(os.path.join(self.output_dir, f'model.pt'))
torch.save(state, latest)
-
- if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
- self.best_acc_idx = self.step_or_epoch
- best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
- torch.save(state, best_ckpt)
- logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+ if self.best_step_or_epoch == "":
+ self.best_step_or_epoch = ckpt_name
+
+            if self.avg_keep_nbest_models_type == "acc":
+                if self.val_acc_step_or_epoch[ckpt_name] >= self.val_acc_step_or_epoch[self.best_step_or_epoch]:
+                    self.best_step_or_epoch = ckpt_name
+                    best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+                    torch.save(state, best_ckpt)
+                    logging.info(f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]}, {best_ckpt}")
+                else:
+                    logging.info(f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]}")
+            elif self.avg_keep_nbest_models_type == "loss":
+                if self.val_loss_step_or_epoch[ckpt_name] <= self.val_loss_step_or_epoch[self.best_step_or_epoch]:
+                    self.best_step_or_epoch = ckpt_name
+                    best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+                    torch.save(state, best_ckpt)
+                    logging.info(f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]}, {best_ckpt}")
+                else:
+                    logging.info(f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]}")
             else:
-                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
-
+                logging.warning(f"Unsupported avg_keep_nbest_models_type: {self.avg_keep_nbest_models_type}")
+            self.saved_ckpts[ckpt_name] = getattr(self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch")[ckpt_name]
if self.keep_nbest_models > 0:
- self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
if len(self.saved_ckpts) > self.keep_nbest_models:
-
- min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
- if min_key in self.saved_ckpts:
- del self.saved_ckpts[min_key]
- filename = os.path.join(self.output_dir, min_key)
+ if self.avg_keep_nbest_models_type == "acc":
+ key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+ else:
+ key = max(self.saved_ckpts, key=self.saved_ckpts.get)
+ if key in self.saved_ckpts:
+ del self.saved_ckpts[key]
+ filename = os.path.join(self.output_dir, key)
logging.info(f"Delete: {filename}")
if os.path.exists(filename):
os.remove(filename)
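The eviction rule above keeps at most `keep_nbest_models` checkpoints: for `acc` the entry with the minimum value is deleted, for `loss` the one with the maximum. The same rule in isolation, with hypothetical values and no file I/O:

```python
def pick_ckpt_to_delete(saved_ckpts: dict, keep_n: int, metric: str = "acc"):
    """Return the checkpoint name to evict once more than keep_n are kept."""
    if len(saved_ckpts) <= keep_n:
        return None
    # worst acc is the minimum; worst loss is the maximum
    return min(saved_ckpts, key=saved_ckpts.get) if metric == "acc" \
        else max(saved_ckpts, key=saved_ckpts.get)

saved = {"model.pt.ep1": 0.81, "model.pt.ep2": 0.86, "model.pt.ep3": 0.84}
assert pick_ckpt_to_delete(saved, keep_n=2, metric="acc") == "model.pt.ep1"
```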
@@ -213,8 +230,10 @@
if scaler is not None and 'scaler_state' in checkpoint:
scaler.load_state_dict(checkpoint['scaler_state'])
- self.val_acc_list = checkpoint["acc"]
- self.step_or_epoch = checkpoint["step_or_epoch"]
+ self.saved_ckpts = checkpoint["saved_ckpts"]
+            self.val_acc_step_or_epoch = checkpoint.get("val_acc_step_or_epoch", {})
+            self.val_loss_step_or_epoch = checkpoint.get("val_loss_step_or_epoch", {})
+            self.best_step_or_epoch = checkpoint.get("best_step_or_epoch", "")
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
@@ -458,8 +477,13 @@
if self.use_ddp or self.use_fsdp:
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-
- self.val_acc_list.append(self.val_acc_avg)
+
+ if kwargs.get("step", None) is None:
+ ckpt_name = f'model.pt.ep{epoch}'
+ else:
+ ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
+            self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+            self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
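Validation metrics are now keyed by checkpoint file name instead of being appended to a list, so resuming and n-best selection can match metrics to files directly. A small sketch of the naming rule used above (hypothetical epoch/step values):

```python
def ckpt_name(epoch: int, step=None) -> str:
    # per-epoch checkpoints vs. intra-epoch (per-step) checkpoints
    return f"model.pt.ep{epoch}" if step is None else f"model.pt.ep{epoch}.{step}"

assert ckpt_name(3) == "model.pt.ep3"
assert ckpt_name(3, 5000) == "model.pt.ep3.5000"
```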
--
Gitblit v1.9.1