From 4e0fcee2a915641e7f39d62c389bee561d849e19 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 19 六月 2023 20:28:23 +0800
Subject: [PATCH] Merge branch 'main' into dev_wjm_infer
---
funasr/runtime/onnxruntime/src/ct-transformer.cpp | 1
egs/alimeeting/sa_asr/local/data/get_utt2dur.sh | 0
egs/alimeeting/sa_asr/local/gen_oracle_embedding.py | 6
egs/alimeeting/sa_asr/local/download_xvector_model.py | 0
egs/alimeeting/sa_asr/local/fix_data_dir.sh | 0
egs/alimeeting/sa_asr/local/text_normalize.pl | 0
egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml | 0
funasr/runtime/ssl_key/readme.md | 4
egs/alimeeting/sa_asr/README.md | 86 ++
egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl | 0
funasr/runtime/java/Makefile | 76 ++
funasr/tasks/sa_asr.py | 6
egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl | 0
funasr/runtime/ssl_key/server.key | 38
egs/alimeeting/sa_asr/local/compute_cmvn.py | 134 +++
egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh | 0
egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml | 0
egs/alimeeting/sa_asr_deprecated/path.sh | 0
funasr/build_utils/build_asr_model.py | 68 +
egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml | 0
egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py | 6
egs/alimeeting/sa_asr/local/format_wav_scp.sh | 0
README.md | 5
egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh | 50
funasr/bin/asr_infer.py | 6
egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh | 0
egs/alimeeting/sa_asr/utils | 0
egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh | 0
egs/alimeeting/sa_asr/local/data/get_reco2dur.sh | 0
egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py | 0
egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py | 3
egs/alimeeting/sa_asr/local/text_format.pl | 0
egs/alimeeting/sa_asr/local/convert_model.py | 29
egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py | 0
egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py | 4
egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py | 0
egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml | 0
egs/alimeeting/sa_asr_deprecated/local | 1
funasr/version.txt | 2
egs/alimeeting/sa_asr/local/combine_data.sh | 0
funasr/models/e2e_sa_asr.py | 17
egs/alimeeting/sa_asr/local/validate_text.pl | 0
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml | 102 ++
egs/alimeeting/sa_asr/local/data/split_data.sh | 0
egs/alimeeting/sa_asr/local/process_text_spk_merge.py | 0
egs/alimeeting/sa_asr_deprecated/README.md | 0
egs/alimeeting/sa_asr/run.sh | 435 ++++++++++++
egs/alimeeting/sa_asr/local/compute_cpcer.py | 0
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml | 131 +++
funasr/bin/train.py | 2
egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh | 0
egs/alimeeting/sa_asr_deprecated/run.sh | 0
egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py | 0
egs/alimeeting/sa_asr/local/apply_map.pl | 0
egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py | 0
egs/alimeeting/sa_asr/local/process_text_id.py | 0
egs/alimeeting/sa_asr_deprecated/asr_local.sh | 0
funasr/runtime/python/onnxruntime/setup.py | 4
egs/alimeeting/sa_asr/local/format_wav_scp.py | 0
egs/alimeeting/sa_asr/local/compute_cmvn.sh | 39 +
egs/alimeeting/sa_asr_deprecated/utils | 0
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 2
egs/alimeeting/sa_asr/local/download_and_untar.sh | 105 ++
funasr/runtime/ssl_key/server.crt | 32
funasr/tasks/asr.py | 152 ++++
egs/alimeeting/sa_asr/local/copy_data_dir.sh | 0
egs/alimeeting/sa_asr/path.sh | 6
/dev/null | 29
funasr/runtime/java/readme.md | 66 +
funasr/runtime/websocket/funasr-wss-client.cpp | 5
egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py | 0
egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh | 0
funasr/models/frontend/default.py | 117 ++
funasr/runtime/onnxruntime/src/paraformer.cpp | 1
egs/alimeeting/sa_asr/local/validate_data_dir.sh | 0
funasr/runtime/java/FunasrWsClient.java | 344 +++++++++
funasr/bin/asr_inference_launch.py | 7
77 files changed, 1,961 insertions(+), 160 deletions(-)
diff --git a/README.md b/README.md
index 7c289e0..76e3301 100644
--- a/README.md
+++ b/README.md
@@ -72,8 +72,8 @@
## Contributors
-| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/DeepScience.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
-|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|:-----------------------------------------------------------:|
+| <div align="left"><img src="docs/images/damo.png" width="180"/> | <div align="left"><img src="docs/images/nwpu.png" width="260"/> | <img src="docs/images/China_Telecom.png" width="200"/> </div> | <img src="docs/images/RapidAI.png" width="200"/> </div> | <img src="docs/images/aihealthx.png" width="200"/> </div> |
+|:---------------------------------------------------------------:|:---------------------------------------------------------------:|:--------------------------------------------------------------:|:-------------------------------------------------------:|:-----------------------------------------------------------:|
## Acknowledge
@@ -82,7 +82,6 @@
3. We referred [Wenet](https://github.com/wenet-e2e/wenet) for building dataloader for large scale data training.
4. We acknowledge [ChinaTelecom](https://github.com/zhuzizyf/damo-fsmn-vad-infer-httpserver) for contributing the VAD runtime.
5. We acknowledge [RapidAI](https://github.com/RapidAI) for contributing the Paraformer and CT_Transformer-punc runtime.
-6. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
6. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the websocket service and html5.
## License
diff --git a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml b/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
deleted file mode 100644
index 68520ae..0000000
--- a/egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-lm: transformer
-lm_conf:
- pos_enc: null
- embed_unit: 128
- att_unit: 512
- head: 8
- unit: 2048
- layer: 16
- dropout_rate: 0.1
-
-# optimization related
-grad_clip: 5.0
-batch_type: numel
-batch_bins: 500000 # 4gpus * 500000
-accum_grad: 1
-max_epoch: 15 # 15epoch is enougth
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 25000
-
-best_model_criterion:
-- - valid
- - loss
- - min
-keep_nbest_models: 10 # 10 is good.
diff --git a/egs/alimeeting/sa_asr/README.md b/egs/alimeeting/sa_asr/README.md
new file mode 100644
index 0000000..1ae023a
--- /dev/null
+++ b/egs/alimeeting/sa_asr/README.md
@@ -0,0 +1,86 @@
+# Get Started
+Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+# Train
+First you need to install the FunASR and ModelScope. ([installation](https://github.com/alibaba-damo-academy/FunASR#installation))
+After FunASR and ModelScope are installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `./dataset` directory should be organized as follows:
+```shell
+dataset
+|—— Eval_Ali_far
+|—— Eval_Ali_near
+|—— Test_Ali_far
+|—— Test_Ali_near
+|—— Train_Ali_far
+|—— Train_Ali_near
+```
+Then you can run this recipe by running:
+```shell
+bash run.sh --stage 0 --stop-stage 6
+```
+There are 8 stages in `run.sh`:
+```shell
+stage 0: Data preparation and remove the audio which is too long or too short.
+stage 1: Speaker profile and CMVN Generation.
+stage 2: Dictionary preparation.
+stage 3: LM training (not supported).
+stage 4: ASR Training.
+stage 5: SA-ASR Training.
+stage 6: Inference
+stage 7: Inference with Test_2023_Ali_far
+```
+<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
+# Infer
+1. Download the final test set and extract it.
+2. Put the audios in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
+3. Set the `test_2023` variable in `run.sh` to `Test_2023_Ali_far`.
+4. Run `run.sh` as follows.
+```shell
+# Prepare test_2023 set
+bash run.sh --stage 0 --stop-stage 1
+# Decode test_2023 set
+bash run.sh --stage 7 --stop-stage 7
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, text_spk_1_A represents the full transcription of speaker_A of Meeting_1 (merged in chronological order), and $ represents the separator symbol. There's no need to worry about the speaker permutation as the optimal permutation will be computed in the end. For more information, please refer to the results generated after executing the baseline code.
+# Baseline Results
+The results of the baseline system are as follows. The baseline results include speaker independent character error rate (SI-CER) and concatenated minimum permutation character error rate (cpCER), the former is speaker independent and the latter is speaker dependent. The speaker profile adopts the oracle speaker embedding during training. However, due to the lack of oracle speaker label during evaluation, the speaker profile provided by an additional spectral clustering is used. Meanwhile, the results of using the oracle speaker profile on Test Set are also provided to show the impact of speaker profile accuracy.
+<!-- <table>
+ <tr >
+ <td rowspan="2"></td>
+ <td colspan="2">SI-CER(%)</td>
+ <td colspan="2">cpCER(%)</td>
+ </tr>
+ <tr>
+ <td>Eval</td>
+ <td>Test</td>
+ <td>Eval</td>
+ <td>Test</td>
+ </tr>
+ <tr>
+ <td>oracle profile</td>
+ <td>32.05</td>
+ <td>32.72</td>
+ <td>47.40</td>
+ <td>42.92</td>
+ </tr>
+ <tr>
+ <td>cluster profile</td>
+ <td>32.05</td>
+ <td>32.73</td>
+ <td>53.76</td>
+ <td>49.37</td>
+ </tr>
+</table> -->
+| |SI-CER(%) |cp-CER(%) |
+|:---------------|:------------:|----------:|
+|oracle profile |32.72 |42.92 |
+|cluster profile|32.73 |49.37 |
+
+
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
rename to egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
diff --git a/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
new file mode 100644
index 0000000..507ad30
--- /dev/null
+++ b/egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
@@ -0,0 +1,102 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+ fs: 16000
+ window: hann
+ n_fft: 400
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+ use_channel: 0
+
+# encoder related
+encoder: conformer
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ rel_pos_type: latest
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder related
+decoder: transformer
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# ctc related
+ctc_conf:
+ ignore_nan_grad: true
+
+# hybrid CTC/attention
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 100
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
diff --git a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
new file mode 100644
index 0000000..47bc6bd
--- /dev/null
+++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
@@ -0,0 +1,131 @@
+# network architecture
+frontend: multichannelfrontend
+frontend_conf:
+ fs: 16000
+ window: hann
+ n_fft: 400
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+ use_channel: 0
+
+# encoder related
+asr_encoder: conformer
+asr_encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+spk_encoder: resnet34_diar
+spk_encoder_conf:
+ use_head_conv: true
+ batchnorm_momentum: 0.5
+ use_head_maxpool: false
+ num_nodes_pooling_layer: 256
+ layers_in_block:
+ - 3
+ - 4
+ - 6
+ - 3
+ filters_in_block:
+ - 32
+ - 64
+ - 128
+ - 256
+ pooling_type: statistic
+ num_nodes_resnet1: 256
+ num_nodes_last_layer: 256
+ batchnorm_momentum: 0.5
+
+# decoder related
+decoder: sa_decoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ asr_num_blocks: 6
+ spk_num_blocks: 3
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+ spk_weight: 0.5
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+ max_spk_num: 4
+
+ctc_conf:
+ ignore_nan_grad: true
+
+# minibatch related
+dataset_conf:
+ data_names: speech,text,profile,text_id
+ data_types: sound,text,npy,text_int
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 7000
+ num_workers: 8
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - acc
+ - max
+- - valid
+ - acc_spk
+ - max
+- - valid
+ - loss
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 8000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
similarity index 74%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
index c13ee42..fd76837 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
+++ b/egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
@@ -21,6 +21,8 @@
SECONDS=0
tgt=Train #Train or Eval
+min_wav_duration=0.1
+max_wav_duration=20
log "$0 $*"
@@ -57,27 +59,24 @@
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
+mkdir -p data/org
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
log "stage 1:process alimeeting near dir"
find -L $near_raw_dir/audio_dir -iname "*.wav" | sort > $near_dir/wavlist
- awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
- find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
+ awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
+ find -L $near_raw_dir/textgrid_dir -iname "*.TextGrid" > $near_dir/textgrid.flist
n1_wav=$(wc -l < $near_dir/wavlist)
n2_text=$(wc -l < $near_dir/textgrid.flist)
log near file found $n1_wav wav and $n2_text text.
- paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
-
- # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -c 1 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
- cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $near_dir/wav.scp
+ paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
- #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
- #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
+
local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@@ -97,9 +96,7 @@
n2_text=$(wc -l < $far_dir/textgrid.flist)
log far file found $n1_wav wav and $n2_text text.
- paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
-
- cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav %s -r 16000 -b 16 -t wav - |\n", $1, $2)}' > $far_dir/wav.scp
+ paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
python local/alimeeting_process_overlap_force.py --path $far_dir \
--no-overlap false --mars True \
@@ -119,28 +116,28 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- log "stage 3: finali data process"
+ log "stage 3: final data process"
local/fix_data_dir.sh $near_dir
local/fix_data_dir.sh $far_dir
- local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
- local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+ local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
+ local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
- sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
- sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
+ sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
+ sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
# remove space in text
for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
+ log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
cp -r $far_dir/* $far_single_speaker_dir
mv $far_single_speaker_dir/textgrid.flist $far_single_speaker_dir/textgrid_oldpath
paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@@ -150,14 +147,15 @@
local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
./local/fix_data_dir.sh $far_single_speaker_dir
- local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+ local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
- cp data/${x}/text data/${x}/text.org
- paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
- > data/${x}/text
- rm data/${x}/text.org
+ cp data/org/${x}/text data/org/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
+ > data/org/${x}/text
+ rm data/org/${x}/text.org
done
+ rm -rf data/local
log "Successfully finished. [elapsed=${SECONDS}s]"
fi
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
rename to egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py b/egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_overlap_force.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py b/egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/alimeeting_process_textgrid.py
rename to egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
diff --git a/egs/alimeeting/sa-asr/local/apply_map.pl b/egs/alimeeting/sa_asr/local/apply_map.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/apply_map.pl
rename to egs/alimeeting/sa_asr/local/apply_map.pl
diff --git a/egs/alimeeting/sa-asr/local/combine_data.sh b/egs/alimeeting/sa_asr/local/combine_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/combine_data.sh
rename to egs/alimeeting/sa_asr/local/combine_data.sh
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.py b/egs/alimeeting/sa_asr/local/compute_cmvn.py
new file mode 100755
index 0000000..d16563a
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.py
@@ -0,0 +1,134 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+from funasr.models.frontend.default import DefaultFrontend
+import torch
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="computer global cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dimension",
+ )
+ parser.add_argument(
+ "--wav_path",
+ default=False,
+ required=True,
+ type=str,
+ help="the path of wav scps",
+ )
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ help="the config file for computing cmvn",
+ )
+ parser.add_argument(
+ "--idx",
+ default=1,
+ required=True,
+ type=int,
+ help="index",
+ )
+ return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
+ total_frames = 0
+
+ # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+ # for key, mat in ark_reader:
+ # mean_stats += np.sum(mat, axis=0)
+ # var_stats += np.sum(np.square(mat), axis=0)
+ # total_frames += mat.shape[0]
+
+ with open(args.config_file) as f:
+ configs = yaml.safe_load(f)
+ frontend_configs = configs.get("frontend_conf", {})
+ num_mel_bins = frontend_configs.get("n_mels", 80)
+ frame_length = frontend_configs.get("frame_length", 25)
+ frame_shift = frontend_configs.get("frame_shift", 10)
+ window_type = frontend_configs.get("window", "hamming")
+ resample_rate = frontend_configs.get("fs", 16000)
+ n_fft = frontend_configs.get("n_fft", "400")
+ use_channel = frontend_configs.get("use_channel", None)
+ assert num_mel_bins == args.dim
+ frontend = DefaultFrontend(
+ fs=resample_rate,
+ n_fft=n_fft,
+ win_length=frame_length * 16,
+ hop_length=frame_shift * 16,
+ window=window_type,
+ n_mels=num_mel_bins,
+ use_channel=use_channel,
+ )
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ wavform, _ = torchaudio.load(wav_file)
+ fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
+ fbank = fbank.squeeze(0).numpy()
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
+
+ cmvn_info = {
+ 'mean_stats': list(mean_stats.tolist()),
+ 'var_stats': list(var_stats.tolist()),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/egs/alimeeting/sa_asr/local/compute_cmvn.sh b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
new file mode 100755
index 0000000..00d08d1
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/compute_cmvn.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+fbankdir=
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+config_file=
+scale=1.0
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+# shellcheck disable=SC2046
+head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+ python local/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --config_file $config_file \
+ --idx JOB \
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
+
+echo "$0: Succeeded compute global cmvn"
diff --git a/egs/alimeeting/sa-asr/local/compute_cpcer.py b/egs/alimeeting/sa_asr/local/compute_cpcer.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/compute_cpcer.py
rename to egs/alimeeting/sa_asr/local/compute_cpcer.py
diff --git a/egs/alimeeting/sa_asr/local/convert_model.py b/egs/alimeeting/sa_asr/local/convert_model.py
new file mode 100644
index 0000000..f0f7997
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/convert_model.py
@@ -0,0 +1,29 @@
+import codecs
+import pdb
+import sys
+import torch
+
+char1 = sys.argv[1]
+char2 = sys.argv[2]
+model1 = torch.load(sys.argv[3], map_location='cpu')
+model2_path = sys.argv[4]
+
+d_new = model1
+char1_list = []
+map_list = []
+
+
+with codecs.open(char1) as f:
+ for line in f.readlines():
+ char1_list.append(line.strip())
+
+with codecs.open(char2) as f:
+ for line in f.readlines():
+ map_list.append(char1_list.index(line.strip()))
+print(map_list)
+
+for k, v in d_new.items():
+ if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
+ d_new[k] = v[map_list]
+
+torch.save(d_new, model2_path)
diff --git a/egs/alimeeting/sa-asr/local/copy_data_dir.sh b/egs/alimeeting/sa_asr/local/copy_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/copy_data_dir.sh
rename to egs/alimeeting/sa_asr/local/copy_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh b/egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh b/egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
rename to egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
diff --git a/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh b/egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
rename to egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
diff --git a/egs/alimeeting/sa-asr/local/data/split_data.sh b/egs/alimeeting/sa_asr/local/data/split_data.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/data/split_data.sh
rename to egs/alimeeting/sa_asr/local/data/split_data.sh
diff --git a/egs/alimeeting/sa_asr/local/download_and_untar.sh b/egs/alimeeting/sa_asr/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/alimeeting/sa_asr/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."; exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py b/egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_pretrained_model_from_modelscope.py
rename to egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
diff --git a/egs/alimeeting/sa-asr/local/download_xvector_model.py b/egs/alimeeting/sa_asr/local/download_xvector_model.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/download_xvector_model.py
rename to egs/alimeeting/sa_asr/local/download_xvector_model.py
diff --git a/egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py b/egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/filter_utt2spk_all_fifo.py
rename to egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
diff --git a/egs/alimeeting/sa-asr/local/fix_data_dir.sh b/egs/alimeeting/sa_asr/local/fix_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/fix_data_dir.sh
rename to egs/alimeeting/sa_asr/local/fix_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa_asr/local/format_wav_scp.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.py
rename to egs/alimeeting/sa_asr/local/format_wav_scp.py
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa_asr/local/format_wav_scp.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/format_wav_scp.sh
rename to egs/alimeeting/sa_asr/local/format_wav_scp.sh
diff --git a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
similarity index 97%
rename from egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
rename to egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
index c37abf9..859b72f 100644
--- a/egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
+++ b/egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
@@ -63,7 +63,7 @@
wav_scp_file = open(path+'/wav.scp', 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -92,8 +92,8 @@
cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
meeting_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
wav = soundfile.read(wav_path)[0]
# take the first channel
if wav.ndim == 2:
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
index 18286b4..2a99b2b 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
@@ -9,7 +9,7 @@
if __name__=="__main__":
path = sys.argv[1] # dump2/raw/Eval_Ali_far
raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
- raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
+ raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
raw_meeting_scp = raw_meeting_scp_file.readlines()
raw_meeting_scp_file.close()
segments_scp_file = open(raw_path + '/segments', 'r')
@@ -22,8 +22,8 @@
raw_wav_map = {}
for line in raw_meeting_scp:
- meeting = line.strip().split('\t')[0]
- wav_path = line.strip().split('\t')[1]
+ meeting = line.strip().split(' ')[0]
+ wav_path = line.strip().split(' ')[1]
raw_wav_map[meeting] = wav_path
spk_map = {}
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_nopadding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
diff --git a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
similarity index 96%
rename from egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
rename to egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
index 186f1de..ff65a1f 100644
--- a/egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
+++ b/egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
@@ -5,7 +5,7 @@
if __name__=="__main__":
- path = sys.argv[1] # dump2/raw/Train_Ali_far
+ path = sys.argv[1]
wav_scp_file = open(path+"/wav.scp", 'r')
wav_scp = wav_scp_file.readlines()
wav_scp_file.close()
@@ -29,7 +29,7 @@
line_list = line.strip().split(' ')
meeting = line_list[0].split('-')[0]
spk_id = line_list[0].split('-')[-1].split('_')[-1]
- spk = meeting+'_' + spk_id
+ spk = meeting + '_' + spk_id
global_spk_list.append(spk)
if meeting in meeting_map_tmp.keys():
meeting_map_tmp[meeting].append(spk)
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
rename to egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
diff --git a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
similarity index 94%
rename from egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
rename to egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
index d900bb1..488344f 100755
--- a/egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
+++ b/egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
@@ -30,8 +30,7 @@
meetingid_map = {}
for line in spk2utt:
spkid = line.strip().split(" ")[0]
- meeting_id_list = spkid.split("_")[:3]
- meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
+ meeting_id = spkid.split("-")[0]
if meeting_id not in meetingid_map:
meetingid_map[meeting_id] = 1
else:
diff --git a/egs/alimeeting/sa-asr/local/process_text_id.py b/egs/alimeeting/sa_asr/local/process_text_id.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_id.py
rename to egs/alimeeting/sa_asr/local/process_text_id.py
diff --git a/egs/alimeeting/sa-asr/local/process_text_spk_merge.py b/egs/alimeeting/sa_asr/local/process_text_spk_merge.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_text_spk_merge.py
rename to egs/alimeeting/sa_asr/local/process_text_spk_merge.py
diff --git a/egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py b/egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
similarity index 100%
rename from egs/alimeeting/sa-asr/local/process_textgrid_to_single_speaker_wav.py
rename to egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
diff --git a/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
rename to egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
diff --git a/egs/alimeeting/sa-asr/local/text_format.pl b/egs/alimeeting/sa_asr/local/text_format.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_format.pl
rename to egs/alimeeting/sa_asr/local/text_format.pl
diff --git a/egs/alimeeting/sa-asr/local/text_normalize.pl b/egs/alimeeting/sa_asr/local/text_normalize.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/text_normalize.pl
rename to egs/alimeeting/sa_asr/local/text_normalize.pl
diff --git a/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
rename to egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
diff --git a/egs/alimeeting/sa-asr/local/validate_data_dir.sh b/egs/alimeeting/sa_asr/local/validate_data_dir.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_data_dir.sh
rename to egs/alimeeting/sa_asr/local/validate_data_dir.sh
diff --git a/egs/alimeeting/sa-asr/local/validate_text.pl b/egs/alimeeting/sa_asr/local/validate_text.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/local/validate_text.pl
rename to egs/alimeeting/sa_asr/local/validate_text.pl
diff --git a/egs/alimeeting/sa_asr/path.sh b/egs/alimeeting/sa_asr/path.sh
new file mode 100755
index 0000000..83ae507
--- /dev/null
+++ b/egs/alimeeting/sa_asr/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
+export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
diff --git a/egs/alimeeting/sa_asr/run.sh b/egs/alimeeting/sa_asr/run.sh
new file mode 100755
index 0000000..43d0da1
--- /dev/null
+++ b/egs/alimeeting/sa_asr/run.sh
@@ -0,0 +1,435 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="6,7"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=8
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="data" # feature output directory
+exp_dir="exp"
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="1.0"
+min_wav_duration=0.1
+max_wav_duration=20
+profile_modes="cluster oracle"
+stage=7
+stop_stage=7
+
+# feature configuration
+feats_dim=80
+nj=32
+
+# data
+raw_data=
+data_url=
+
+# exp tag
+tag=""
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=Train_Ali_far
+valid_set=Eval_Ali_far
+test_sets="Test_Ali_far Eval_Ali_far"
+test_2023="Test_2023_Ali_far_release"
+
+asr_config=conf/train_asr_conformer.yaml
+sa_asr_config=conf/train_sa_asr_conformer.yaml
+asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+inference_config=conf/decode_asr_rnn.yaml
+inference_sa_asr_model=valid.acc_spk.ave.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
+ # remove long/short data
+ for x in ${train_set} ${valid_set} ${test_sets}; do
+ cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
+ rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
+ _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
+ <"${feats_dir}/${x}/utt2num_samples" \
+ awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
+ >"${feats_dir}/${x}/utt2num_samples_rmls"
+ mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
+ <"${feats_dir}/${x}/wav.scp" \
+ utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples" \
+ >"${feats_dir}/${x}/wav.scp_rmls"
+ mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
+ <"${feats_dir}/${x}/text" \
+ awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
+ mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
+ local/fix_data_dir.sh "${feats_dir}/${x}"
+ <"${feats_dir}/${x}/utt2spk_all_fifo" \
+ utils/filter_scp.pl "${feats_dir}/${x}/text" \
+ >"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
+ mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
+ done
+ for x in ${test_2023}; do
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ --audio-format wav --segments ${feats_dir}/org/${x}/segments \
+ "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
+ cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
+ paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
+ cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Speaker profile and CMVN Generation"
+
+ mkdir -p "profile_log"
+ for x in "${train_set}" "${valid_set}" ${test_sets}; do
+ # generate text_id spk2id
+ python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
+ # generate text_id_train for sot
+ python local/process_text_id.py ${feats_dir}/${x}
+ echo "Successfully generate ${feats_dir}/${x}/text_id_train"
+ # generate oracle_embedding from single-speaker audio segment
+ echo "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${x}.log"
+ python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
+ echo "Successfully generate oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
+ # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
+ if [ "${x}" = "${train_set}" ]; then
+ python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
+ else
+ python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
+ echo "Successfully generate oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
+ fi
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ if [ "${x}" != "${train_set}" ]; then
+ echo "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${x}.log"
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ fi
+ # compute CMVN
+ if [ "${x}" = "${train_set}" ]; then
+ local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+ fi
+ done
+
+ for x in "${test_2023}"; do
+ # generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
+ python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
+ echo "Successfully generate cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
+ done
+fi
+
+token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "Stage 4: ASR Training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ mkdir -p ${asr_exp}
+ mkdir -p ${asr_exp}/log
+ INIT_FILE=${asr_exp}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ # i=0
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --token_type char \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/${asr_model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+
+fi
+
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "SA-ASR training"
+ asr_exp=${exp_dir}/${asr_model_dir}
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ mkdir -p ${sa_asr_exp}
+ mkdir -p ${sa_asr_exp}/log
+ INIT_FILE=${sa_asr_exp}/ddp_init
+ if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
+ ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
+ ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
+ fi
+
+ if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
+ # download xvector extractor model file
+ python local/download_xvector_model.py ${exp_dir}
+ echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
+ fi
+
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $ngpu; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --model sa_asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --unused_parameters true \
+ --token_type char \
+ --resume true \
+ --token_list $token_list \
+ --data_dir ${feats_dir} \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text,profile.scp,text_id_train" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
+ --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder" \
+ --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense" \
+ --output_dir ${exp_dir}/${sa_asr_model_dir} \
+ --config $sa_asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "stage 6: Inference test sets"
+ for x in ${test_sets}; do
+ for profile_mode in ${profile_modes}; do
+ echo "decoding ${x} with ${profile_mode} profile"
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} already exists. If you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ if [ $profile_mode = "oracle" ]; then
+ profile_scp=${profile_mode}_profile_nopadding.scp
+ else
+ profile_scp=${profile_mode}_profile_infer.scp
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
+ sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
+ python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
+
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+
+ python local/process_text_spk_merge.py ${_dir}
+ python local/process_text_spk_merge.py ${_data}
+
+ python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
+ tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
+ cat ${_dir}/text.cpcer.txt
+ done
+ done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+ echo "stage 7: Inference test 2023"
+ for x in ${test_2023}; do
+ sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} already exists. If you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${x}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --mc True \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --nbest 1 \
+ --gpuid_list ${gpuid_list} \
+ --allow_variable_data_keys true \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
+ --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${sa_asr_exp}"/config.yaml \
+ --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode sa_asr \
+ ${_opts}
+
+ for f in token token_int score text text_id; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+
+ python local/process_text_spk_merge.py ${_dir}
+
+ done
+fi
+
+
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa_asr/utils
similarity index 100%
rename from egs/alimeeting/sa-asr/utils
rename to egs/alimeeting/sa_asr/utils
diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa_asr_deprecated/README.md
similarity index 100%
rename from egs/alimeeting/sa-asr/README.md
rename to egs/alimeeting/sa_asr_deprecated/README.md
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa_asr_deprecated/asr_local.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/asr_local.sh
rename to egs/alimeeting/sa_asr_deprecated/asr_local.sh
diff --git a/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
rename to egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
diff --git a/egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml b/egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
similarity index 100%
copy from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml
copy to egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
rename to egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
similarity index 100%
rename from egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
rename to egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
diff --git a/egs/alimeeting/sa_asr_deprecated/local b/egs/alimeeting/sa_asr_deprecated/local
new file mode 120000
index 0000000..2ef6217
--- /dev/null
+++ b/egs/alimeeting/sa_asr_deprecated/local
@@ -0,0 +1 @@
+../sa_asr/local/
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa_asr_deprecated/path.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/path.sh
rename to egs/alimeeting/sa_asr_deprecated/path.sh
diff --git a/egs/alimeeting/sa-asr/run.sh b/egs/alimeeting/sa_asr_deprecated/run.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/run.sh
rename to egs/alimeeting/sa_asr_deprecated/run.sh
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
rename to egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa_asr_deprecated/utils
similarity index 100%
copy from egs/alimeeting/sa-asr/utils
copy to egs/alimeeting/sa_asr_deprecated/utils
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 140b424..c722ebc 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1636,8 +1636,10 @@
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend == 'wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ from funasr.tasks.sa_asr import frontend_choices
+ if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 367b9a8..656a965 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -619,7 +619,12 @@
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
- batch_size_token_ms = batch_size_token * 60
+
+ batch_size_token_ms = batch_size_token*60
+ if speech2text.device == "cpu":
+ batch_size_token_ms = 0
+ batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index f4fc0a7..1dc3fb5 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -301,7 +301,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index d4a954c..200395d 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -6,7 +6,6 @@
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
@@ -20,17 +19,23 @@
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
- ContextualParaformer
+
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+
+from funasr.models.e2e_sa_asr import SAASRModel
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -93,6 +98,8 @@
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
+
),
default="asr",
)
@@ -109,6 +116,27 @@
chunk_conformer=ConformerChunkEncoder,
),
default="rnn",
+)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ default="rnn",
+)
+
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
)
encoder_choices2 = ClassChoices(
"encoder2",
@@ -134,6 +162,7 @@
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
default="rnn",
)
@@ -225,6 +254,10 @@
rnnt_decoder_choices,
# --joint_network and --joint_network_conf
joint_network_choices,
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
]
@@ -247,7 +280,7 @@
# frontend
if hasattr(args, "input_size") and args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
- if args.frontend == 'wav_frontend':
+ if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
@@ -425,6 +458,33 @@
joint_network=joint_network,
**args.model_conf,
)
+ elif args.model == "sa_asr":
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 8304607..e209d51 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -40,7 +40,7 @@
yield
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -51,10 +51,8 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
@@ -89,8 +87,6 @@
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
- self.preencoder = preencoder
- self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
@@ -293,10 +289,6 @@
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
else:
encoder_out_spk=encoder_out_spk_ori
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -337,7 +324,7 @@
)
if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
+ return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 19994f0..6718f3f 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -2,7 +2,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
-
+import logging
import humanfriendly
import numpy as np
import torch
@@ -14,6 +14,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
class DefaultFrontend(AbsFrontend):
@@ -137,8 +138,6 @@
return input_stft, feats_lens
-
-
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -147,9 +146,9 @@
def __init__(
self,
fs: Union[int, str] = 16000,
- n_fft: int = 512,
- win_length: int = None,
- hop_length: int = 128,
+ n_fft: int = 400,
+ frame_length: int = 25,
+ frame_shift: int = 10,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
@@ -160,10 +159,10 @@
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
- frame_length: int = None,
- frame_shift: int = None,
- lfr_m: int = None,
- lfr_n: int = None,
+ use_channel: int = None,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ cmvn_file: str = None
):
assert check_argument_types()
super().__init__()
@@ -172,13 +171,14 @@
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
- self.hop_length = hop_length
+ self.win_length = frame_length * 16
+ self.hop_length = frame_shift * 16
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
- win_length=win_length,
- hop_length=hop_length,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
center=center,
window=window,
normalized=normalized,
@@ -202,7 +202,17 @@
htk=htk,
)
self.n_mels = n_mels
- self.frontend_type = "multichannelfrontend"
+ self.frontend_type = "default"
+ self.use_channel = use_channel
+ if self.use_channel is not None:
+ logging.info("use the channel %d" % (self.use_channel))
+ else:
+ logging.info("random select channel")
+ self.cmvn_file = cmvn_file
+ if self.cmvn_file is not None:
+ mean, std = self._load_cmvn(self.cmvn_file)
+ self.register_buffer("mean", torch.from_numpy(mean))
+ self.register_buffer("std", torch.from_numpy(std))
def output_size(self) -> int:
return self.n_mels
@@ -215,16 +225,29 @@
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
- if isinstance(input, ComplexTensor):
- input_stft = input
- else:
- input_stft = ComplexTensor(input[..., 0], input[..., 1])
+ input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
+
+ # 3. [Multi channel case]: Select a channel
+ if input_stft.dim() == 4:
+ # h: (B, T, C, F) -> h: (B, T, F)
+ if self.training:
+ if self.use_channel is not None:
+ input_stft = input_stft[:, :, self.use_channel, :]
+
+ else:
+ # Select 1ch randomly
+ ch = np.random.randint(input_stft.size(2))
+ input_stft = input_stft[:, :, ch, :]
+ else:
+ # Use the first channel
+ input_stft = input_stft[:, :, 0, :]
+
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +256,27 @@
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
- bt = input_feats.size(0)
- if input_feats.dim() ==4:
- channel_size = input_feats.size(2)
- # batch * channel * T * D
- #pdb.set_trace()
- input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
- # input_feats = input_feats.transpose(1,2)
- # batch * channel
- feats_lens = feats_lens.repeat(1,channel_size).squeeze()
- else:
- channel_size = 1
- return input_feats, feats_lens, channel_size
+
+ # 6. Apply CMVN
+ if self.cmvn_file is not None:
+ if feats_lens is None:
+ feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+ self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+ self.std = self.std.to(input_feats.device, input_feats.dtype)
+ mask = make_pad_mask(feats_lens, input_feats, 1)
+
+ if input_feats.requires_grad:
+ input_feats = input_feats + self.mean
+ else:
+ input_feats += self.mean
+ if input_feats.requires_grad:
+ input_feats = input_feats.masked_fill(mask, 0.0)
+ else:
+ input_feats.masked_fill_(mask, 0.0)
+
+ input_feats *= self.std
+
+ return input_feats, feats_lens
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +290,27 @@
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
\ No newline at end of file
+ return input_stft, feats_lens
+
+ def _load_cmvn(self, cmvn_file):
+ with open(cmvn_file, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+ means_list = []
+ vars_list = []
+ for i in range(len(lines)):
+ line_item = lines[i].split()
+ if line_item[0] == '<AddShift>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ add_shift_line = line_item[3:(len(line_item) - 1)]
+ means_list = list(add_shift_line)
+ continue
+ elif line_item[0] == '<Rescale>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ rescale_line = line_item[3:(len(line_item) - 1)]
+ vars_list = list(rescale_line)
+ continue
+ means = np.array(means_list).astype(np.float)
+ vars = np.array(vars_list).astype(np.float)
+ return means, vars
\ No newline at end of file
diff --git a/funasr/runtime/java/FunasrWsClient.java b/funasr/runtime/java/FunasrWsClient.java
new file mode 100644
index 0000000..ec55c94
--- /dev/null
+++ b/funasr/runtime/java/FunasrWsClient.java
@@ -0,0 +1,344 @@
+//
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License (https://opensource.org/licenses/MIT)
+//
+/*
+ * // 2022-2023 by zhaomingwork@qq.com
+ */
+// java FunasrWsClient
+// usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+// [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+package websocket;
+
+import java.io.*;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.*;
+import java.util.Map;
+import net.sourceforge.argparse4j.ArgumentParsers;
+import net.sourceforge.argparse4j.inf.ArgumentParser;
+import net.sourceforge.argparse4j.inf.ArgumentParserException;
+import net.sourceforge.argparse4j.inf.Namespace;
+import org.java_websocket.client.WebSocketClient;
+import org.java_websocket.drafts.Draft;
+import org.java_websocket.handshake.ServerHandshake;
+import org.json.simple.JSONArray;
+import org.json.simple.JSONObject;
+import org.json.simple.parser.JSONParser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** This example demonstrates how to connect to websocket server. */
+public class FunasrWsClient extends WebSocketClient {
+
+ public class RecWavThread extends Thread {
+ private FunasrWsClient funasrClient;
+
+ public RecWavThread(FunasrWsClient funasrClient) {
+ this.funasrClient = funasrClient;
+ }
+
+ public void run() {
+ this.funasrClient.recWav();
+ }
+ }
+
+ private static final Logger logger = LoggerFactory.getLogger(FunasrWsClient.class);
+
+ public FunasrWsClient(URI serverUri, Draft draft) {
+ super(serverUri, draft);
+ }
+
+ public FunasrWsClient(URI serverURI) {
+ super(serverURI);
+ }
+
+ public FunasrWsClient(URI serverUri, Map<String, String> httpHeaders) {
+ super(serverUri, httpHeaders);
+ }
+
+ public void getSslContext(String keyfile, String certfile) {
+ // TODO
+ return;
+ }
+
+ // send json at first time
+ public void sendJson(
+ String mode, String strChunkSize, int chunkInterval, String wavName, boolean isSpeaking) {
+ try {
+
+ JSONObject obj = new JSONObject();
+ obj.put("mode", mode);
+ JSONArray array = new JSONArray();
+ String[] chunkList = strChunkSize.split(",");
+ for (int i = 0; i < chunkList.length; i++) {
+ array.add(Integer.valueOf(chunkList[i].trim()));
+ }
+
+ obj.put("chunk_size", array);
+ obj.put("chunk_interval", new Integer(chunkInterval));
+ obj.put("wav_name", wavName);
+ if (isSpeaking) {
+ obj.put("is_speaking", new Boolean(true));
+ } else {
+ obj.put("is_speaking", new Boolean(false));
+ }
+ logger.info("sendJson: " + obj);
+ // return;
+
+ send(obj.toString());
+
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // send json at end of wav
+ public void sendEof() {
+ try {
+ JSONObject obj = new JSONObject();
+
+ obj.put("is_speaking", new Boolean(false));
+
+ logger.info("sendEof: " + obj);
+ // return;
+
+ send(obj.toString());
+ iseof = true;
+ return;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ // function for rec wav file
+ public void recWav() {
+ sendJson(mode, strChunkSize, chunkInterval, wavName, true);
+ File file = new File(FunasrWsClient.wavPath);
+
+ int chunkSize = sendChunkSize;
+ byte[] bytes = new byte[chunkSize];
+
+ int readSize = 0;
+ try (FileInputStream fis = new FileInputStream(file)) {
+ if (FunasrWsClient.wavPath.endsWith(".wav")) {
+ fis.read(bytes, 0, 44); //skip first 44 wav header
+ }
+ readSize = fis.read(bytes, 0, chunkSize);
+ while (readSize > 0) {
+ // send when it is chunk size
+ if (readSize == chunkSize) {
+ send(bytes); // send buf to server
+
+ } else {
+ // send when at last or not is chunk size
+ byte[] tmpBytes = new byte[readSize];
+ for (int i = 0; i < readSize; i++) {
+ tmpBytes[i] = bytes[i];
+ }
+ send(tmpBytes);
+ }
+ // if not in offline mode, we simulate online stream by sleep
+ if (!mode.equals("offline")) {
+ Thread.sleep(Integer.valueOf(chunkSize / 32));
+ }
+
+ readSize = fis.read(bytes, 0, chunkSize);
+ }
+
+ if (!mode.equals("offline")) {
+ // if not offline, we send eof and wait for 3 seconds to close
+ Thread.sleep(2000);
+ sendEof();
+ Thread.sleep(3000);
+ close();
+ } else {
+ // if offline, just send eof
+ sendEof();
+ }
+
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void onOpen(ServerHandshake handshakedata) {
+
+ RecWavThread thread = new RecWavThread(this);
+ thread.start();
+ }
+
+ @Override
+ public void onMessage(String message) {
+ JSONObject jsonObject = new JSONObject();
+ JSONParser jsonParser = new JSONParser();
+ logger.info("received: " + message);
+ try {
+ jsonObject = (JSONObject) jsonParser.parse(message);
+ logger.info("text: " + jsonObject.get("text"));
+ } catch (org.json.simple.parser.ParseException e) {
+ e.printStackTrace();
+ }
+ if (iseof && mode.equals("offline")) {
+ close();
+ }
+ }
+
+ @Override
+ public void onClose(int code, String reason, boolean remote) {
+
+ logger.info(
+ "Connection closed by "
+ + (remote ? "remote peer" : "us")
+ + " Code: "
+ + code
+ + " Reason: "
+ + reason);
+ }
+
+ @Override
+ public void onError(Exception ex) {
+ logger.info("ex: " + ex);
+ ex.printStackTrace();
+ // if the error is fatal then onClose will be called additionally
+ }
+
+ private boolean iseof = false;
+ public static String wavPath;
+ static String mode = "online";
+ static String strChunkSize = "5,10,5";
+ static int chunkInterval = 10;
+ static int sendChunkSize = 1920;
+
+ String wavName = "javatest";
+
+ public static void main(String[] args) throws URISyntaxException {
+ ArgumentParser parser = ArgumentParsers.newArgumentParser("ws client").defaultHelp(true);
+ parser
+ .addArgument("--port")
+ .help("Port on which to listen.")
+ .setDefault("8889")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--host")
+ .help("the IP address of server.")
+ .setDefault("127.0.0.1")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--audio_in")
+ .help("wav path for decoding.")
+ .setDefault("asr_example.wav")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--num_threads")
+ .help("num of threads for test.")
+ .setDefault(1)
+ .type(Integer.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_size")
+ .help("chunk size for asr.")
+ .setDefault("5, 10, 5")
+ .type(String.class)
+ .required(false);
+ parser
+ .addArgument("--chunk_interval")
+ .help("chunk for asr.")
+ .setDefault(10)
+ .type(Integer.class)
+ .required(false);
+
+ parser
+ .addArgument("--mode")
+ .help("mode for asr.")
+ .setDefault("offline")
+ .type(String.class)
+ .required(false);
+ String srvIp = "";
+ String srvPort = "";
+ String wavPath = "";
+ int numThreads = 1;
+ String chunk_size = "";
+ int chunk_interval = 10;
+ String strmode = "offline";
+
+ try {
+ Namespace ns = parser.parseArgs(args);
+ srvIp = ns.get("host");
+ srvPort = ns.get("port");
+ wavPath = ns.get("audio_in");
+ numThreads = ns.get("num_threads");
+ chunk_size = ns.get("chunk_size");
+ chunk_interval = ns.get("chunk_interval");
+ strmode = ns.get("mode");
+ System.out.println(srvPort);
+
+ } catch (ArgumentParserException ex) {
+ ex.getParser().handleError(ex);
+ return;
+ }
+
+ FunasrWsClient.strChunkSize = chunk_size;
+ FunasrWsClient.chunkInterval = chunk_interval;
+ FunasrWsClient.wavPath = wavPath;
+ FunasrWsClient.mode = strmode;
+ System.out.println(
+ "serIp="
+ + srvIp
+ + ",srvPort="
+ + srvPort
+ + ",wavPath="
+ + wavPath
+ + ",strChunkSize"
+ + strChunkSize);
+
+ class ClientThread implements Runnable {
+
+ String srvIp;
+ String srvPort;
+
+ ClientThread(String srvIp, String srvPort, String wavPath) {
+ this.srvIp = srvIp;
+ this.srvPort = srvPort;
+ }
+
+ public void run() {
+ try {
+
+ int RATE = 16000;
+ String[] chunkList = strChunkSize.split(",");
+ int int_chunk_size = 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval;
+ int CHUNK = Integer.valueOf(RATE / 1000 * int_chunk_size);
+ int stride =
+ Integer.valueOf(
+ 60 * Integer.valueOf(chunkList[1].trim()) / chunkInterval / 1000 * 16000 * 2);
+ System.out.println("chunk_size:" + String.valueOf(int_chunk_size));
+ System.out.println("CHUNK:" + CHUNK);
+ System.out.println("stride:" + String.valueOf(stride));
+ FunasrWsClient.sendChunkSize = CHUNK * 2;
+
+ String wsAddress = "ws://" + srvIp + ":" + srvPort;
+
+ FunasrWsClient c = new FunasrWsClient(new URI(wsAddress));
+
+ c.connect();
+
+ System.out.println("wsAddress:" + wsAddress);
+ } catch (Exception e) {
+ e.printStackTrace();
+ System.out.println("e:" + e);
+ }
+ }
+ }
+ for (int i = 0; i < numThreads; i++) {
+ System.out.println("Thread1 is running...");
+ Thread t = new Thread(new ClientThread(srvIp, srvPort, wavPath));
+ t.start();
+ }
+ }
+}
diff --git a/funasr/runtime/java/Makefile b/funasr/runtime/java/Makefile
new file mode 100644
index 0000000..9a70ca5
--- /dev/null
+++ b/funasr/runtime/java/Makefile
@@ -0,0 +1,76 @@
+
+ENTRY_POINT = ./
+
+
+
+
+WEBSOCKET_DIR:= ./
+WEBSOCKET_FILES = \
+ $(WEBSOCKET_DIR)/FunasrWsClient.java \
+
+
+
+LIB_BUILD_DIR = ./lib
+
+
+
+
+JAVAC = javac
+
+BUILD_DIR = build
+
+
+RUNJFLAGS = -Dfile.encoding=utf-8
+
+
+vpath %.class $(BUILD_DIR)
+vpath %.java src
+
+
+
+
+rebuild: clean all
+
+.PHONY: clean run downjar
+
+downjar:
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
+ #wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://repo1.maven.org/maven2/org/java-websocket/Java-WebSocket/1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
+ wget https://storage.googleapis.com/google-code-archive-downloads/v2/code.google.com/json-simple/json-simple-1.1.1.jar -P ./lib/
+ wget https://github.com/argparse4j/argparse4j/releases/download/argparse4j-0.9.0/argparse4j-0.9.0.jar -P ./lib/
+ rm -frv build
+ mkdir build
+clean:
+ rm -frv $(BUILD_DIR)/*
+ rm -frv $(LIB_BUILD_DIR)/*
+ mkdir -p $(BUILD_DIR)
+ mkdir -p ./lib
+
+
+
+
+
+
+runclient:
+ java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar $(RUNJFLAGS) websocket.FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+
+
+buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
+
+
+%.class: %.java
+
+ $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:lib/json-simple-1.1.1.jar:lib/argparse4j-0.9.0.jar -d $(BUILD_DIR) -encoding UTF-8 $<
+
+packjar:
+ jar cvfe lib/funasrclient.jar . -C $(BUILD_DIR) .
+
+all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
+
+
+
+
+
diff --git a/funasr/runtime/java/readme.md b/funasr/runtime/java/readme.md
new file mode 100644
index 0000000..406a21a
--- /dev/null
+++ b/funasr/runtime/java/readme.md
@@ -0,0 +1,66 @@
+# Java websocket client example
+
+
+
+## Building for Linux/Unix
+
+### install java environment
+```shell
+# in ubuntu
+apt-get install openjdk-11-jdk
+```
+
+
+
+### Build and run by make
+
+
+```shell
+cd funasr/runtime/java
+# download java lib
+make downjar
+# compile
+make buildwebsocket
+# run client
+make runclient
+
+```
+
+## Run java websocket client by shell
+
+```shell
+# full command refer to Makefile runclient
+usage: FunasrWsClient [-h] [--port PORT] [--host HOST] [--audio_in AUDIO_IN] [--num_threads NUM_THREADS]
+ [--chunk_size CHUNK_SIZE] [--chunk_interval CHUNK_INTERVAL] [--mode MODE]
+
+Where:
+ --host <string>
+ (required) server-ip
+
+ --port <int>
+ (required) port
+
+ --audio_in <string>
+ (required) the wav or pcm file path
+
+ --num_threads <int>
+ thread number for test
+
+  --mode
+    asr mode, supports "offline", "online", "2pass"
+
+
+
+example:
+FunasrWsClient --host localhost --port 8889 --audio_in ./asr_example.wav --num_threads 1 --mode 2pass
+
+result json, example like:
+{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"javatest"}
+```
+
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/java-ws-client-support/funasr/runtime/java) for contributing the java websocket client example.
+
+
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index d2692ce..a4ee7f7 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -65,7 +65,7 @@
n_total_length += snippet_time;
FunASRFreeResult(result);
}else{
- LOG(ERROR) << ("No return data!\n");
+ LOG(ERROR) << wav_ids[i] << (": No return data!\n");
}
}
{
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 58eec25..2ee4114 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -18,6 +18,7 @@
try{
m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << punc_model;
}
catch (std::exception const &e) {
LOG(ERROR) << "Error when load punc onnx model: " << e.what();
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 1957a12..b605fff 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -33,6 +33,7 @@
try {
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
exit(0);
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 64e363f..246d678 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.1.0'
+VERSION_NUM = '0.1.1'
setuptools.setup(
name=MODULE_NAME,
@@ -31,7 +31,7 @@
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
- "typeguard",
+ "typeguard==2.13.3",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"funasr",
diff --git a/funasr/runtime/ssl_key/readme.md b/funasr/runtime/ssl_key/readme.md
index a5989e6..8a48dd3 100644
--- a/funasr/runtime/ssl_key/readme.md
+++ b/funasr/runtime/ssl_key/readme.md
@@ -3,7 +3,7 @@
```shell
### 1) Generate a private key
-openssl genrsa -des3 -out server.key 1024
+openssl genrsa -des3 -out server.key 2048
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
@@ -14,4 +14,4 @@
### 4) Generated a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
-```
\ No newline at end of file
+```
diff --git a/funasr/runtime/ssl_key/server.crt b/funasr/runtime/ssl_key/server.crt
index 808b73e..5a5079d 100644
--- a/funasr/runtime/ssl_key/server.crt
+++ b/funasr/runtime/ssl_key/server.crt
@@ -1,15 +1,21 @@
-----BEGIN CERTIFICATE-----
-MIICSDCCAbECFCObiVAMkMlCGmMDGDFx5Nx3XYvOMA0GCSqGSIb3DQEBCwUAMGMx
-CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdCZWlqaW5nMRAwDgYDVQQHDAdCZWlqaW5n
-MRAwDgYDVQQKDAdhbGliYWJhMQwwCgYDVQQLDANhc3IxEDAOBgNVBAMMB2FsaWJh
-YmEwHhcNMjMwNTEyMTQzNjAxWhcNMjQwNTExMTQzNjAxWjBjMQswCQYDVQQGEwJD
-TjEQMA4GA1UECAwHQmVpamluZzEQMA4GA1UEBwwHQmVpamluZzEQMA4GA1UECgwH
-YWxpYmFiYTEMMAoGA1UECwwDYXNyMRAwDgYDVQQDDAdhbGliYWJhMIGfMA0GCSqG
-SIb3DQEBAQUAA4GNADCBiQKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETS
-NBrrRfjbBv6ucAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc
-38+VLo9UjjsoOeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdL
-mV/toQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAB5KNWF1XIIYD1geMsyT6/ZRnGNA
-dmeUyMcwYvIlQG3boSipNk/JI4W5fFOg1O2sAqflYHmwZfmasAQsC2e5bSzHZ+PB
-uMJhKYxfj81p175GumHTw5Lbp2CvFSLrnuVB0ThRdcCqEh1MDt0D3QBuBr/ZKgGS
-hXtozVCgkSJzX6uD
+MIIDhTCCAm0CFGB0Po2IZ0hESavFpcSGRNb9xrNXMA0GCSqGSIb3DQEBCwUAMH8x
+CzAJBgNVBAYTAkNOMRAwDgYDVQQIDAdiZWlqaW5nMRAwDgYDVQQHDAdiZWlqaW5n
+MRAwDgYDVQQKDAdhbGliYWJhMRAwDgYDVQQLDAdhbGliYWJhMRAwDgYDVQQDDAdh
+bGliYWJhMRYwFAYJKoZIhvcNAQkBFgdhbGliYWJhMB4XDTIzMDYxODA2NTcxM1oX
+DTI0MDYxNzA2NTcxM1owfzELMAkGA1UEBhMCQ04xEDAOBgNVBAgMB2JlaWppbmcx
+EDAOBgNVBAcMB2JlaWppbmcxEDAOBgNVBAoMB2FsaWJhYmExEDAOBgNVBAsMB2Fs
+aWJhYmExEDAOBgNVBAMMB2FsaWJhYmExFjAUBgkqhkiG9w0BCQEWB2FsaWJhYmEw
+ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDH9Np1oBunQKMt5M/nU2nD
+qVHojXwKKwyiK9DSeGikKwArH2S9NUZNu5RDg46u0iWmT+Vz+toQhkJnfatOVskW
+f2bsI54n5eOvmoWOKDXYm2MscvjkuNiYRbqzgUuP9ZSx8k3uyRs++wvmwIoU+PV1
+EYFcjk1P2jUGUvKaUlmIDsjs1wOMIbKO6I0UX20FNKlGWacqMR/Dx2ltmGKT1Kaz
+Y335lor0bcfQtH542rGS7PDz6JMRNjFT1VFcmnrjRElf4STbaOiIfOjMVZ/9O8Hr
+LFItyvkb01Mt7O0jhAXHuE1l/8Y0N3MCYkELG9mQA0BYCFHY0FLuJrGoU03b8KWj
+AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEjC9jB1WZe2ki2JgCS+eAMFsFegiNEz
+D0klVB3kiCPK0g7DCxvfWR6kAgEynxRxVX6TN9QcLr4paZItC1Fu2gUMTteNqEuc
+dcixJdu9jumuUMBlAKgL5Yyk3alSErsn9ZVF/Q8Kx5arMO/TW3Ulsd8SWQL5C/vq
+Fe0SRhpKKoADPfl8MT/XMfB/MwNxVhYDSHzJ1EiN8O5ce6q2tTdi1mlGquzNxhjC
+7Q0F36V1HksfzolrlRWRKYP16isnaKUdFfeAzaJsYw33o6VRbk6fo2fTQDHS0wOs
+Q48Moc5UxKMLaMMCqLPpWu0TZse+kIw1nTWXk7yJtK0HK5PN3rTocEw=
-----END CERTIFICATE-----
diff --git a/funasr/runtime/ssl_key/server.key b/funasr/runtime/ssl_key/server.key
index aac8b26..8efdcb8 100644
--- a/funasr/runtime/ssl_key/server.key
+++ b/funasr/runtime/ssl_key/server.key
@@ -1,15 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
-MIICXQIBAAKBgQDEINLLMasJtJQPoesCfcwJsjiUkx3hLnoUyETSNBrrRfjbBv6u
-cAgZIF+/V15IfJZR6u2ULpJN0wUg8xNQReu4kdpjSdNGuQ0aoWbc38+VLo9Ujjso
-OeoeCro6b0u+GosPoEuI4t7Ky09zw+FBibD95daJ3GDY1DGCbDdLmV/toQIDAQAB
-AoGARpA0pwygp+ZDWvh7kDLoZRitCK+BkZHiNHX1ZNeAU+Oh7FOw79u43ilqqXHq
-pxPEFYb7oVO8Kanhb4BlE32EmApBlvhd3SW07kn0dS7WVGsTvPFwKKpF88W8E+pc
-2i8At5tr2O1DZhvqNdIN7r8FRrGQ/Hpm3ItypUdz2lZnMwECQQD3dILOMJ84O2JE
-NxUwk8iOYefMJftQUO57Gm7XBVke/i3r9uajSqB2xmOvUaSyaHoJfx/mmfgfxYcD
-M+Re6mERAkEAyuaV5+eD82eG2I8PgxJ2p5SOb1x5F5qpb4KuKAlfHEkdolttMwN3
-7vl1ZWUZLVu2rHnUmvbYV2gkQO1os7/DkQJBAIDYfbN2xbC12vjB5ZqhmG/qspMt
-w6mSOlqG7OewtTLaDncq2/RySxMNQaJr1GHA3KpNMwMTcIq6gw472tFBIMECQF0z
-fjiASEROkcp4LI/ws0BXJPZSa+1DxgDK7mTFqUK88zfY91gvh6/mNt7UibQkJM0l
-SVvFd6ru03hflXC77YECQQDDQrB9ApwVOMGQw+pwbxn9p8tPYVi3oBiUfYgd1RDO
-uhcRgxv7gT4BSiyI4nFBMCYyI28azTLlUiJhMr9MNUpB
+MIIEowIBAAKCAQEAx/TadaAbp0CjLeTP51Npw6lR6I18CisMoivQ0nhopCsAKx9k
+vTVGTbuUQ4OOrtIlpk/lc/raEIZCZ32rTlbJFn9m7COeJ+Xjr5qFjig12JtjLHL4
+5LjYmEW6s4FLj/WUsfJN7skbPvsL5sCKFPj1dRGBXI5NT9o1BlLymlJZiA7I7NcD
+jCGyjuiNFF9tBTSpRlmnKjEfw8dpbZhik9Sms2N9+ZaK9G3H0LR+eNqxkuzw8+iT
+ETYxU9VRXJp640RJX+Ek22joiHzozFWf/TvB6yxSLcr5G9NTLeztI4QFx7hNZf/G
+NDdzAmJBCxvZkANAWAhR2NBS7iaxqFNN2/ClowIDAQABAoIBAQC1/STX6eFBWJMs
+MhUHdePNMU5bWmqK1qOo9jgZV33l7T06Alit3M8f8JoA2LwEYT/jHtS3upi+cXP+
+vWIs6tAaqdoDEmff6FxSd1EXEYHwo3yf+ASQJ6z66nwC5KrhW6L6Uo6bxm4F5Hfw
+jU0fyXeeFVCn7Nxw0SlxmA02Z70VFsL8BK9i3kajU18y6drf4VUm55oMEtdEmOh2
+eKn4qspBcNblbw+L0QJ+5kN1iRUyJHesQ1GpS+L3yeMVFCW7ctL4Bgw8Z7LE+z7i
+C0Weyhul8vuT+7nfF2T37zsSa8iixqpkTokeYh96CZ5nDqa2IDx3oNHWSlkIsV6g
+6EUEl9gBAoGBAPIw/M6fIDetMj8f1wG7mIRgJsxI817IS6aBSwB5HkoCJFfrR9Ua
+jMNCFIWNs/Om8xeGhq/91hbnCYDNK06V5CUa/uk4CYRs2eQZ3FKoNowtp6u/ieuU
+qg8bXM/vR2VWtWVixAMdouT3+KtvlgaVmSnrPiwO4pecGrwu5NW1oJCFAoGBANNb
+aE3AcwTDYsqh0N/75G56Q5s1GZ6MCDQGQSh8IkxL6Vg59KnJiIKQ7AxNKFgJZMtY
+zZHaqjazeHjOGTiYiC7MMVJtCcOBEfjCouIG8btNYv7Y3dWnOXRZni2telAsRrH9
+xS5LaFdCRTjVAwSsppMGwiQtyl6sGLMyz0SXoYoHAoGAKdkFFb6xFm26zOV3hTkg
+9V6X1ZyVUL9TMwYMK5zB+w+7r+VbmBrqT6LPYPRHL8adImeARlCZ+YMaRUMuRHnp
+3e94NFwWaOdWDu/Y/f9KzZXl7us9rZMWf12+/77cm0oMNeSG8fLg/qdKNHUneyPG
+P1QCfiJkTMYQaIvBxpuHjvECgYAKlZ9JlYOtD2PZJfVh4il0ZucP1L7ts7GNeWq1
+7lGBZKPQ6UYZYqBVeZB4pTyJ/B5yGIZi8YJoruAvnJKixPC89zjZGeDNS59sx8KE
+cziT2rJEdPPXCULVUs+bFf70GOOJcl33jYsyI3139SLrjwHghwwd57UkvJWYE8lR
+dA6A7QKBgEfTC+NlzqLPhbB+HPl6CvcUczcXcI9M0heVz/DNMA+4pjxPnv2aeIwh
+cL2wq2xr+g1wDBWGVGkVSuZhXm5E6gDetdyVeJnbIUhVjBblnbhHV6GrudjbXGnJ
+W9cBgu6DswyHU2cOsqmimu8zLmG6/dQYFHt+kUWGxN8opCzVjgWa
-----END RSA PRIVATE KEY-----
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 4a3c751..8b59000 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -91,7 +91,6 @@
using websocketpp::lib::placeholders::_1;
m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
- // m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
m_client.set_message_handler(
[this](websocketpp::connection_hdl hdl, message_ptr msg) {
@@ -218,7 +217,7 @@
}
}
if (wait) {
- LOG(INFO) << "wait.." << m_open;
+ // LOG(INFO) << "wait.." << m_open;
WaitABit();
continue;
}
@@ -292,7 +291,7 @@
false, 1, "int");
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
- false, 0, "int");
+ false, 1, "int");
cmd.add(server_ip_);
cmd.add(port_);
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 8244856..7338513 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,6 +38,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +46,7 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +56,7 @@
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +137,7 @@
timestamp_prediction=TimestampPredictor,
rnnt=TransducerModel,
rnnt_unified=UnifiedTransducerModel,
+ sa_asr=SAASRModel,
),
type_check=FunASRModel,
default="asr",
@@ -175,6 +179,27 @@
type_check=AbsEncoder,
default="rnn",
)
+asr_encoder_choices = ClassChoices(
+ "asr_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+ "spk_encoder",
+ classes=dict(
+ resnet34_diar=ResNet34Diar,
+ ),
+ default="resnet34_diar",
+)
postencoder_choices = ClassChoices(
name="postencoder",
classes=dict(
@@ -197,6 +222,7 @@
paraformer_decoder_sanm=ParaformerSANMDecoder,
paraformer_decoder_san=ParaformerDecoderSAN,
contextual_paraformer_decoder=ContextualParaformerDecoder,
+ sa_decoder=SAAsrTransformerDecoder,
),
type_check=AbsDecoder,
default="rnn",
@@ -328,6 +354,12 @@
type=str2bool,
default=True,
help="whether to split text using <space>",
+ )
+ group.add_argument(
+ "--max_spk_num",
+ type=int_or_none,
+ default=None,
+ help="A text mapping int-id to token",
)
group.add_argument(
"--seg_dict_file",
@@ -1495,3 +1527,123 @@
#assert check_return_type(model)
return model
+
+
+class ASRTaskSAASR(ASRTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ # --asr_encoder and --asr_encoder_conf
+ asr_encoder_choices,
+ # --spk_encoder and --spk_encoder_conf
+ spk_encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 5. Encoder
+ asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+ asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+ spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+ spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+
+ # 7. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=asr_encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # 8. CTC
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+ )
+
+ # import ipdb;ipdb.set_trace()
+ # 9. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ asr_encoder=asr_encoder,
+ spk_encoder=spk_encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 4769758..9579483 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -39,7 +39,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +120,7 @@
model_choices = ClassChoices(
"model",
classes=dict(
- asr=ESPnetASRModel,
+ asr=SAASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_bert=ParaformerBert,
@@ -620,4 +620,4 @@
initialize(model, args.init)
assert check_return_type(model)
- return model
+ return model
\ No newline at end of file
diff --git a/funasr/version.txt b/funasr/version.txt
index ee6cdce..b616048 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.6.1
+0.6.2
--
Gitblit v1.9.1