yhliang
2023-06-16 e8528b8f6208cee52ed9c02ecfa9185f84706502
Dev lyh (#645)

* update

* update

* fix bug

* fix bug
7 files modified
10 files added
43 files renamed
1 file deleted
2 files copied
1530 lines changed
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml 29
egs/alimeeting/sa_asr/README.md 86
egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml 102
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml 131
egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh 50
egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
egs/alimeeting/sa_asr/local/apply_map.pl
egs/alimeeting/sa_asr/local/combine_data.sh
egs/alimeeting/sa_asr/local/compute_cmvn.py 134
egs/alimeeting/sa_asr/local/compute_cmvn.sh 39
egs/alimeeting/sa_asr/local/compute_cpcer.py
egs/alimeeting/sa_asr/local/convert_model.py 29
egs/alimeeting/sa_asr/local/copy_data_dir.sh
egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
egs/alimeeting/sa_asr/local/data/split_data.sh
egs/alimeeting/sa_asr/local/download_and_untar.sh 105
egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
egs/alimeeting/sa_asr/local/download_xvector_model.py
egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
egs/alimeeting/sa_asr/local/fix_data_dir.sh
egs/alimeeting/sa_asr/local/format_wav_scp.py
egs/alimeeting/sa_asr/local/format_wav_scp.sh
egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py 6
egs/alimeeting/sa_asr/local/gen_oracle_embedding.py 6
egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py 4
egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py 3
egs/alimeeting/sa_asr/local/process_text_id.py
egs/alimeeting/sa_asr/local/process_text_spk_merge.py
egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
egs/alimeeting/sa_asr/local/text_format.pl
egs/alimeeting/sa_asr/local/text_normalize.pl
egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
egs/alimeeting/sa_asr/local/validate_data_dir.sh
egs/alimeeting/sa_asr/local/validate_text.pl
egs/alimeeting/sa_asr/path.sh 6
egs/alimeeting/sa_asr/run.sh 435
egs/alimeeting/sa_asr/utils
egs/alimeeting/sa_asr_deprecated/README.md
egs/alimeeting/sa_asr_deprecated/asr_local.sh
egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
egs/alimeeting/sa_asr_deprecated/local 1
egs/alimeeting/sa_asr_deprecated/path.sh
egs/alimeeting/sa_asr_deprecated/run.sh
egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
egs/alimeeting/sa_asr_deprecated/utils
funasr/bin/asr_infer.py 10
funasr/bin/train.py 2
funasr/build_utils/build_asr_model.py 60
funasr/models/e2e_sa_asr.py 17
funasr/models/frontend/default.py 117
funasr/tasks/asr.py 152
funasr/tasks/sa_asr.py 6
egs/alimeeting/sa-asr/conf/train_lm_transformer.yaml
File was deleted
egs/alimeeting/sa_asr/README.md
New file
@@ -0,0 +1,86 @@
# Get Started
Speaker Attributed Automatic Speech Recognition (SA-ASR) is a task proposed to solve "who spoke what". Specifically, the goal of SA-ASR is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example is referenced in the paper: [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
# Train
First you need to install FunASR and ModelScope. ([installation](https://github.com/alibaba-damo-academy/FunASR#installation))
After FunASR and ModelScope are installed, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it in the `./dataset` directory. The `./dataset` directory should be organized as follows:
```shell
dataset
|—— Eval_Ali_far
|—— Eval_Ali_near
|—— Test_Ali_far
|—— Test_Ali_near
|—— Train_Ali_far
|—— Train_Ali_near
```
Then you can run this recipe with:
```shell
bash run.sh --stage 0 --stop-stage 6
```
There are 8 stages in `run.sh`:
```shell
stage 0: Data preparation; remove audio that is too long or too short.
stage 1: Speaker profile and CMVN generation.
stage 2: Dictionary preparation.
stage 3: LM training (not supported).
stage 4: ASR training.
stage 5: SA-ASR training.
stage 6: Inference on the test sets.
stage 7: Inference with Test_2023_Ali_far.
```
<!-- The baseline model is available on [ModelScope](https://www.modelscope.cn/models/damo/speech_saasr_asr-zh-cn-16k-alimeeting/summary). -->
# Infer
1. Download the final test set and extract it.
2. Put the audio files in `./dataset/Test_2023_Ali_far/` and put the `wav.scp`, `segments`, `utt2spk`, `spk2utt` in `./data/org/Test_2023_Ali_far/`.
3. Set `test_2023` in `run.sh` to `Test_2023_Ali_far`.
4. Run `run.sh` as follows:
```shell
# Prepare test_2023 set
bash run.sh --stage 0 --stop-stage 1
# Decode test_2023 set
bash run.sh --stage 7 --stop-stage 7
```
# Format of Final Submission
Finally, you need to submit a file called `text_spk_merge` with the following format:
```shell
Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
...
```
Here, `text_spk_1_A` represents the full transcription of speaker_A in Meeting_1 (merged in chronological order), and `$` is the separator symbol. There is no need to worry about the speaker permutation, as the optimal permutation is computed during scoring. For more information, please refer to the results generated after executing the baseline code.
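To make the scoring concrete, here is a minimal sketch of an optimal-permutation CER, assuming per-speaker texts split on `$` and a plain character-level Levenshtein distance; the actual scorer is `local/compute_cpcer.py`, which may differ in detail:
```python
from itertools import permutations

def edit_distance(ref: str, hyp: str) -> int:
    # character-level Levenshtein distance
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]

def cp_cer(ref_entry: str, hyp_entry: str) -> float:
    # split one meeting's entry into per-speaker texts
    ref, hyp = ref_entry.split("$"), hyp_entry.split("$")
    n = max(len(ref), len(hyp))
    ref += [""] * (n - len(ref))   # pad so both sides have n "speakers"
    hyp += [""] * (n - len(hyp))
    best = min(sum(edit_distance(r, hyp[p]) for r, p in zip(ref, perm))
               for perm in permutations(range(n)))
    return best / max(sum(len(r) for r in ref), 1)

print(f"{cp_cer('今天天气不错$我们开始开会吧', '我们开始开会$今天天气不错'):.2%}")
# 7.69%: the best permutation swaps the two hypothesis speakers,
# leaving 1 deleted character out of 13 reference characters
```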
# Baseline Results
The results of the baseline system are as follows, reported as speaker-independent character error rate (SI-CER) and concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels, while the latter is speaker dependent. The speaker profile uses the oracle speaker embeddings during training. Since oracle speaker labels are unavailable during evaluation, the speaker profile produced by an additional spectral clustering step is used instead. The results with the oracle speaker profile on the Test set are also provided to show the impact of speaker-profile accuracy.
<!-- <table>
    <tr >
        <td rowspan="2"></td>
        <td colspan="2">SI-CER(%)</td>
        <td colspan="2">cpCER(%)</td>
    </tr>
    <tr>
        <td>Eval</td>
        <td>Test</td>
        <td>Eval</td>
        <td>Test</td>
    </tr>
    <tr>
        <td>oracle profile</td>
        <td>32.05</td>
        <td>32.72</td>
        <td>47.40</td>
        <td>42.92</td>
    </tr>
    <tr>
        <td>cluster profile</td>
        <td>32.05</td>
        <td>32.73</td>
        <td>53.76</td>
        <td>49.37</td>
    </tr>
</table> -->
|                |SI-CER(%)     |cpCER(%)   |
|:---------------|:------------:|----------:|
|oracle profile  |32.72         |42.92      |
|cluster profile |32.73         |49.37      |
# Reference
N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Interspeech. ISCA, 2021, pp. 4413–4417.
egs/alimeeting/sa_asr/conf/decode_asr_rnn.yaml
egs/alimeeting/sa_asr/conf/train_asr_conformer.yaml
New file
@@ -0,0 +1,102 @@
# network architecture
frontend: multichannelfrontend
frontend_conf:
    fs: 16000
    window: hann
    n_fft: 400
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
    use_channel: 0
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    rel_pos_type: latest
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15
# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# ctc related
ctc_conf:
    ignore_nan_grad: true
# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
dataset_conf:
    data_names: speech,text
    data_types: sound,text
    shuffle: True
    shuffle_conf:
        shuffle_size: 2048
        sort_size: 500
    batch_conf:
        batch_type: token
        batch_size: 7000
    num_workers: 8
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 100
val_scheduler_criterion:
    - valid
    - acc
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 25000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
New file
@@ -0,0 +1,131 @@
# network architecture
frontend: multichannelfrontend
frontend_conf:
    fs: 16000
    window: hann
    n_fft: 400
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
    use_channel: 0
# encoder related
asr_encoder: conformer
asr_encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15
spk_encoder: resnet34_diar
spk_encoder_conf:
  use_head_conv: true
  batchnorm_momentum: 0.5
  use_head_maxpool: false
  num_nodes_pooling_layer: 256
  layers_in_block:
    - 3
    - 4
    - 6
    - 3
  filters_in_block:
    - 32
    - 64
    - 128
    - 256
  pooling_type: statistic
  num_nodes_resnet1: 256
  num_nodes_last_layer: 256
# decoder related
decoder: sa_decoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    asr_num_blocks: 6
    spk_num_blocks: 3
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
    spk_weight: 0.5
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
    max_spk_num: 4
ctc_conf:
    ignore_nan_grad: true
# minibatch related
dataset_conf:
    data_names: speech,text,profile,text_id
    data_types: sound,text,npy,text_int
    shuffle: True
    shuffle_conf:
        shuffle_size: 2048
        sort_size: 500
    batch_conf:
        batch_type: token
        batch_size: 7000
    num_workers: 8
# optimization related
accum_grad: 1
grad_clip: 5
max_epoch: 60
val_scheduler_criterion:
    - valid
    - loss
best_model_criterion:
-   - valid
    - acc
    - max
-   - valid
    - acc_spk
    - max
-   - valid
    - loss
    - min
keep_nbest_models: 10
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 8000
specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
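For reference, `model_conf` above sets `spk_weight: 0.5` and `ctc_weight: 0.3`. Assuming the usual hybrid CTC/attention formulation extended with a speaker loss (as in the Kanda et al. paper), the combined objective would look roughly like the sketch below; the exact interpolation lives in `funasr/models/e2e_sa_asr.py` and may differ:
```python
import torch

def sa_asr_loss(loss_att, loss_ctc, loss_spk,
                ctc_weight=0.3, spk_weight=0.5):
    # hybrid CTC/attention ASR loss, then interpolated with the
    # speaker-attribution loss via spk_weight
    loss_asr = ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
    return (1.0 - spk_weight) * loss_asr + spk_weight * loss_spk

print(sa_asr_loss(torch.tensor(2.0), torch.tensor(3.0), torch.tensor(1.0)))
# 0.5 * (0.3*3 + 0.7*2) + 0.5 * 1 = tensor(1.6500)
```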
egs/alimeeting/sa_asr/local/alimeeting_data_prep.sh
File was renamed from egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -21,6 +21,8 @@
SECONDS=0
tgt=Train #Train or Eval
min_wav_duration=0.1
max_wav_duration=20
log "$0 $*"
@@ -57,27 +59,24 @@
stop_stage=4
mkdir -p $far_dir
mkdir -p $near_dir
mkdir -p data/org
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 
    log "stage 1:process alimeeting near dir"
    
    find -L $near_raw_dir/audio_dir -iname "*.wav" | sort >  $near_dir/wavlist
    awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' > $near_dir/uttid
    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" | sort > $near_dir/textgrid.flist
    awk -F '/' '{print $NF}' $near_dir/wavlist | awk -F '.' '{print $1}' | sort > $near_dir/uttid
    find -L $near_raw_dir/textgrid_dir  -iname "*.TextGrid" > $near_dir/textgrid.flist
    n1_wav=$(wc -l < $near_dir/wavlist)
    n2_text=$(wc -l < $near_dir/textgrid.flist)
    log  near file found $n1_wav wav and $n2_text text.
    paste $near_dir/uttid $near_dir/wavlist > $near_dir/wav_raw.scp
    # cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -c 1 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
    cat $near_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $near_dir/wav.scp
    paste $near_dir/uttid $near_dir/wavlist -d " " > $near_dir/wav.scp
    
    python local/alimeeting_process_textgrid.py --path $near_dir --no-overlap False
    cat $near_dir/text_all | local/text_normalize.pl | local/text_format.pl | sort -u > $near_dir/text
    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $near_dir/utt2spk_old >$near_dir/tmp1
    #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
    local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
    sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
@@ -97,9 +96,7 @@
    n2_text=$(wc -l < $far_dir/textgrid.flist)
    log  far file found $n1_wav wav and $n2_text text.
    paste $far_dir/uttid $far_dir/wavlist > $far_dir/wav_raw.scp
    cat $far_dir/wav_raw.scp | awk '{printf("%s sox -t wav  %s -r 16000 -b 16 -t wav  - |\n", $1, $2)}'  > $far_dir/wav.scp
    paste $far_dir/uttid $far_dir/wavlist -d " " > $far_dir/wav.scp
    python local/alimeeting_process_overlap_force.py  --path $far_dir \
        --no-overlap false --mars True \
@@ -119,28 +116,28 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3: finali data process"
    log "stage 3: final data process"
    local/fix_data_dir.sh $near_dir
    local/fix_data_dir.sh $far_dir
    local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
    local/copy_data_dir.sh $near_dir data/org/${tgt}_Ali_near
    local/copy_data_dir.sh $far_dir data/org/${tgt}_Ali_far
    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
    sort $far_dir/utt2spk_all_fifo > data/org/${tgt}_Ali_far/utt2spk_all_fifo
    sed -i "s/src/$/g" data/org/${tgt}_Ali_far/utt2spk_all_fifo
    # remove space in text
    for x in ${tgt}_Ali_near ${tgt}_Ali_far; do
        cp data/${x}/text data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
        > data/${x}/text
        rm data/${x}/text.org
        cp data/org/${x}/text data/org/${x}/text.org
        paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
        > data/org/${x}/text
        rm data/org/${x}/text.org
    done
    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    log "stage 4: process alimeeting far dir (single speaker by oracle time strap)"
    log "stage 4: process alimeeting far dir (single speaker by oracle time stamp)"
    cp -r $far_dir/* $far_single_speaker_dir 
    mv $far_single_speaker_dir/textgrid.flist  $far_single_speaker_dir/textgrid_oldpath
    paste -d " " $far_single_speaker_dir/uttid $far_single_speaker_dir/textgrid_oldpath > $far_single_speaker_dir/textgrid.flist
@@ -150,14 +147,15 @@
    local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
    ./local/fix_data_dir.sh $far_single_speaker_dir 
    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
    local/copy_data_dir.sh $far_single_speaker_dir data/org/${tgt}_Ali_far_single_speaker
    # remove space in text
    for x in ${tgt}_Ali_far_single_speaker; do
        cp data/${x}/text data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " data/${x}/text.org) <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \
        > data/${x}/text
        rm data/${x}/text.org
        cp data/org/${x}/text data/org/${x}/text.org
        paste -d " " <(cut -f 1 -d" " data/org/${x}/text.org) <(cut -f 2- -d" " data/org/${x}/text.org | tr -d " ") \
        > data/org/${x}/text
        rm data/org/${x}/text.org
    done
    rm -rf data/local
    log "Successfully finished. [elapsed=${SECONDS}s]"
fi
egs/alimeeting/sa_asr/local/alimeeting_data_prep_test_2023.sh
egs/alimeeting/sa_asr/local/alimeeting_process_overlap_force.py
egs/alimeeting/sa_asr/local/alimeeting_process_textgrid.py
egs/alimeeting/sa_asr/local/apply_map.pl
egs/alimeeting/sa_asr/local/combine_data.sh
egs/alimeeting/sa_asr/local/compute_cmvn.py
New file
@@ -0,0 +1,134 @@
import argparse
import json
import os
import numpy as np
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from funasr.models.frontend.default import DefaultFrontend
import torch
def get_parser():
    parser = argparse.ArgumentParser(
        description="computer global cmvn",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dim",
        default=80,
        type=int,
        help="feature dimension",
    )
    parser.add_argument(
        "--wav_path",
        default=False,
        required=True,
        type=str,
        help="the path of wav scps",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        help="the config file for computing cmvn",
    )
    parser.add_argument(
        "--idx",
        default=1,
        required=True,
        type=int,
        help="index",
    )
    return parser
def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  resample_rate=16000,
                  speed=1.0,
                  window_type="hamming"):
    waveform, sample_rate = torchaudio.load(wav_file)
    if resample_rate != sample_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                                  new_freq=resample_rate)(waveform)
    if speed != 1.0:
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform, resample_rate,
            [['speed', str(speed)], ['rate', str(resample_rate)]]
        )
    waveform = waveform * (1 << 15)
    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type=window_type,
                      sample_frequency=resample_rate)
    return mat.numpy()
def main():
    parser = get_parser()
    args = parser.parse_args()
    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
    mean_stats = np.zeros(args.dim)
    var_stats = np.zeros(args.dim)
    total_frames = 0
    # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
    #     for key, mat in ark_reader:
    #         mean_stats += np.sum(mat, axis=0)
    #         var_stats += np.sum(np.square(mat), axis=0)
    #         total_frames += mat.shape[0]
    with open(args.config_file) as f:
        configs = yaml.safe_load(f)
        frontend_configs = configs.get("frontend_conf", {})
        num_mel_bins = frontend_configs.get("n_mels", 80)
        frame_length = frontend_configs.get("frame_length", 25)
        frame_shift = frontend_configs.get("frame_shift", 10)
        window_type = frontend_configs.get("window", "hamming")
        resample_rate = frontend_configs.get("fs", 16000)
        n_fft = frontend_configs.get("n_fft", 400)  # default should be an int, not the string "400"
        use_channel = frontend_configs.get("use_channel", None)
        assert num_mel_bins == args.dim
    frontend = DefaultFrontend(
        fs=resample_rate,
        n_fft=n_fft,
        win_length=frame_length * 16,  # ms -> samples at 16 kHz
        hop_length=frame_shift * 16,   # ms -> samples at 16 kHz
        window=window_type,
        n_mels=num_mel_bins,
        use_channel=use_channel,
    )
    with open(wav_scp_file) as f:
        lines = f.readlines()
        for line in lines:
            _, wav_file = line.strip().split()
            wavform, _ = torchaudio.load(wav_file)
            fbank, _ = frontend(wavform.transpose(0, 1).unsqueeze(0), torch.tensor([wavform.shape[1]]))
            fbank = fbank.squeeze(0).numpy()
            mean_stats += np.sum(fbank, axis=0)
            var_stats += np.sum(np.square(fbank), axis=0)
            total_frames += fbank.shape[0]
    cmvn_info = {
        'mean_stats': mean_stats.tolist(),
        'var_stats': var_stats.tolist(),
        'total_frames': total_frames
    }
    with open(cmvn_file, 'w') as fout:
        fout.write(json.dumps(cmvn_info))
if __name__ == '__main__':
    main()
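The shard files written above (`cmvn.{idx}.json`) are merged by `utils/combine_cmvn_file.py` in `compute_cmvn.sh`. A minimal sketch of that merge, assuming only that the global mean and inverse standard deviation are derived from the summed statistics (the real script's output format may differ):
```python
import glob
import json
import numpy as np

def combine_cmvn(json_dir, dim=80):
    mean_stats, var_stats, frames = np.zeros(dim), np.zeros(dim), 0
    for path in glob.glob(f"{json_dir}/cmvn.*.json"):
        with open(path) as f:
            info = json.load(f)
        mean_stats += np.array(info["mean_stats"])
        var_stats += np.array(info["var_stats"])
        frames += info["total_frames"]
    mean = mean_stats / frames                    # E[x]
    var = var_stats / frames - mean ** 2          # E[x^2] - (E[x])^2
    istd = 1.0 / np.sqrt(np.maximum(var, 1e-20))  # clamp for stability
    return mean, istd
```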
egs/alimeeting/sa_asr/local/compute_cmvn.sh
New file
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# Begin configuration section.
fbankdir=
nj=32
cmd=./utils/run.pl
feats_dim=80
config_file=
scale=1.0
echo "$0 $@"
. utils/parse_options.sh || exit 1;
# shellcheck disable=SC2046
head -n $(awk -v lines="$(wc -l < ${fbankdir}/wav.scp)" -v scale="$scale" 'BEGIN { printf "%.0f\n", lines*scale }') ${fbankdir}/wav.scp > ${fbankdir}/wav.scp.scale
split_dir=${fbankdir}/cmvn/split_${nj};
mkdir -p $split_dir
split_scps=""
for n in $(seq $nj); do
    split_scps="$split_scps $split_dir/wav.$n.scp"
done
utils/split_scp.pl ${fbankdir}/wav.scp.scale $split_scps || exit 1;
logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
    python local/compute_cmvn.py \
      --dim ${feats_dim} \
      --wav_path $split_dir \
      --config_file $config_file \
      --idx JOB \
python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/am.mvn
echo "$0: Succeeded compute global cmvn"
egs/alimeeting/sa_asr/local/compute_cpcer.py
egs/alimeeting/sa_asr/local/convert_model.py
New file
@@ -0,0 +1,29 @@
import codecs
import sys
import torch
char1 = sys.argv[1]
char2 = sys.argv[2]
model1 = torch.load(sys.argv[3], map_location='cpu')
model2_path = sys.argv[4]
d_new = model1
char1_list = []
map_list = []
with codecs.open(char1) as f:
    for line in f.readlines():
        char1_list.append(line.strip())
with codecs.open(char2) as f:
    for line in f.readlines():
        map_list.append(char1_list.index(line.strip()))
print(map_list)
vocab_dependent_keys = ('ctc.ctc_lo.weight', 'ctc.ctc_lo.bias',
                        'decoder.output_layer.weight', 'decoder.output_layer.bias',
                        'decoder.embed.0.weight')
for k, v in d_new.items():
    if k in vocab_dependent_keys:
        d_new[k] = v[map_list]
torch.save(d_new, model2_path)
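`convert_model.py` reorders the vocabulary-dependent rows of a checkpoint so that a model trained with the `char1` token list can be loaded under the `char2` ordering. A toy check of the `v[map_list]` indexing (the names below are made up for illustration):
```python
import torch

char1 = ["<blank>", "a", "b", "c"]             # token order the model was trained with
char2 = ["<blank>", "c", "a", "b"]             # desired token order
map_list = [char1.index(tok) for tok in char2]  # [0, 3, 1, 2]

weight = torch.arange(8.0).reshape(4, 2)  # stand-in output-layer weight, one row per token
print(weight[map_list])                   # rows reordered to <blank>, c, a, b
```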
egs/alimeeting/sa_asr/local/copy_data_dir.sh
egs/alimeeting/sa_asr/local/data/get_reco2dur.sh
egs/alimeeting/sa_asr/local/data/get_segments_for_data.sh
egs/alimeeting/sa_asr/local/data/get_utt2dur.sh
egs/alimeeting/sa_asr/local/data/split_data.sh
egs/alimeeting/sa_asr/local/download_and_untar.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0
remove_archive=false
if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi
if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi
data=$1
url=$2
part=$3
if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi
if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi
if [ -f $data/$part/.complete ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi
# sizes of the archive files in bytes.
sizes="15582913665 1246920"
if [ -f $data/$part.tgz ]; then
  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm $data/$part.tgz
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi
if [ ! -f $data/$part.tgz ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
  cd $data || exit 1
  if ! wget --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi
cd $data || exit 1
if ! tar -xvzf $part.tgz; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi
touch $data/$part/.complete
if [ $part == "data_aishell" ]; then
  cd $data/$part/wav || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf $wav && rm $wav
  done
fi
echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm $data/$part.tgz
fi
exit 0;
egs/alimeeting/sa_asr/local/download_pretrained_model_from_modelscope.py
egs/alimeeting/sa_asr/local/download_xvector_model.py
egs/alimeeting/sa_asr/local/filter_utt2spk_all_fifo.py
egs/alimeeting/sa_asr/local/fix_data_dir.sh
egs/alimeeting/sa_asr/local/format_wav_scp.py
egs/alimeeting/sa_asr/local/format_wav_scp.sh
egs/alimeeting/sa_asr/local/gen_cluster_profile_infer.py
File was renamed from egs/alimeeting/sa-asr/local/gen_cluster_profile_infer.py
@@ -63,7 +63,7 @@
    wav_scp_file = open(path+'/wav.scp', 'r')
    wav_scp = wav_scp_file.readlines()
    wav_scp_file.close()
    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
    raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
    raw_meeting_scp = raw_meeting_scp_file.readlines()
    raw_meeting_scp_file.close()
    segments_scp_file = open(raw_path + '/segments', 'r')
@@ -92,8 +92,8 @@
    cluster_spk_num_file = open(path + '/cluster_spk_num', 'w')
    meeting_map = {}
    for line in raw_meeting_scp:
        meeting = line.strip().split('\t')[0]
        wav_path = line.strip().split('\t')[1]
        meeting = line.strip().split(' ')[0]
        wav_path = line.strip().split(' ')[1]
        wav = soundfile.read(wav_path)[0]
        # take the first channel
        if wav.ndim == 2:
egs/alimeeting/sa_asr/local/gen_oracle_embedding.py
File was renamed from egs/alimeeting/sa-asr/local/gen_oracle_embedding.py
@@ -9,7 +9,7 @@
if __name__=="__main__":
    path = sys.argv[1] # dump2/raw/Eval_Ali_far
    raw_path = sys.argv[2] # data/local/Eval_Ali_far_correct_single_speaker
    raw_meeting_scp_file = open(raw_path + '/wav_raw.scp', 'r')
    raw_meeting_scp_file = open(raw_path + '/wav.scp', 'r')
    raw_meeting_scp = raw_meeting_scp_file.readlines()
    raw_meeting_scp_file.close()
    segments_scp_file = open(raw_path + '/segments', 'r')
@@ -22,8 +22,8 @@
    raw_wav_map = {}
    for line in raw_meeting_scp:
        meeting = line.strip().split('\t')[0]
        wav_path = line.strip().split('\t')[1]
        meeting = line.strip().split(' ')[0]
        wav_path = line.strip().split(' ')[1]
        raw_wav_map[meeting] = wav_path
    
    spk_map = {}
egs/alimeeting/sa_asr/local/gen_oracle_profile_nopadding.py
egs/alimeeting/sa_asr/local/gen_oracle_profile_padding.py
File was renamed from egs/alimeeting/sa-asr/local/gen_oracle_profile_padding.py
@@ -5,7 +5,7 @@
if __name__=="__main__":
    path = sys.argv[1] # dump2/raw/Train_Ali_far
    path = sys.argv[1]
    wav_scp_file = open(path+"/wav.scp", 'r')
    wav_scp = wav_scp_file.readlines()
    wav_scp_file.close()
@@ -29,7 +29,7 @@
        line_list = line.strip().split(' ')
        meeting = line_list[0].split('-')[0]
        spk_id = line_list[0].split('-')[-1].split('_')[-1]
        spk = meeting+'_' + spk_id
        spk = meeting + '_' + spk_id
        global_spk_list.append(spk)
        if meeting in meeting_map_tmp.keys():
            meeting_map_tmp[meeting].append(spk)
egs/alimeeting/sa_asr/local/perturb_data_dir_speed.sh
egs/alimeeting/sa_asr/local/process_sot_fifo_textchar2spk.py
File was renamed from egs/alimeeting/sa-asr/local/process_sot_fifo_textchar2spk.py
@@ -30,8 +30,7 @@
    meetingid_map = {}
    for line in spk2utt:
        spkid = line.strip().split(" ")[0]
        meeting_id_list = spkid.split("_")[:3]
        meeting_id = meeting_id_list[0] + "_" + meeting_id_list[1] + "_" + meeting_id_list[2]
        meeting_id = spkid.split("-")[0]
        if meeting_id not in meetingid_map:
            meetingid_map[meeting_id] = 1     
        else:
egs/alimeeting/sa_asr/local/process_text_id.py
egs/alimeeting/sa_asr/local/process_text_spk_merge.py
egs/alimeeting/sa_asr/local/process_textgrid_to_single_speaker_wav.py
egs/alimeeting/sa_asr/local/spk2utt_to_utt2spk.pl
egs/alimeeting/sa_asr/local/text_format.pl
egs/alimeeting/sa_asr/local/text_normalize.pl
egs/alimeeting/sa_asr/local/utt2spk_to_spk2utt.pl
egs/alimeeting/sa_asr/local/validate_data_dir.sh
egs/alimeeting/sa_asr/local/validate_text.pl
egs/alimeeting/sa_asr/path.sh
New file
@@ -0,0 +1,6 @@
export FUNASR_DIR=$PWD/../../..
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
egs/alimeeting/sa_asr/run.sh
New file
@@ -0,0 +1,435 @@
#!/usr/bin/env bash
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="6,7"
gpu_num=2
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=8
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="data" # feature output directory
exp_dir="exp"
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="1.0"
min_wav_duration=0.1
max_wav_duration=20
profile_modes="cluster oracle"
stage=7
stop_stage=7
# feature configuration
feats_dim=80
nj=32
# data
raw_data=
data_url=
# exp tag
tag=""
. utils/parse_options.sh || exit 1;
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=Train_Ali_far
valid_set=Eval_Ali_far
test_sets="Test_Ali_far Eval_Ali_far"
test_2023="Test_2023_Ali_far_release"
asr_config=conf/train_asr_conformer.yaml
sa_asr_config=conf/train_sa_asr_conformer.yaml
asr_model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
sa_asr_model_dir="baseline_$(basename "${sa_asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_rnn.yaml
inference_sa_asr_model=valid.acc_spk.ave.pb
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    inference_nj=$[${ngpu}*${njob}]
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    ./local/alimeeting_data_prep.sh --tgt Test --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
    ./local/alimeeting_data_prep.sh --tgt Eval --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
    ./local/alimeeting_data_prep.sh --tgt Train --min_wav_duration $min_wav_duration --max_wav_duration $max_wav_duration
    # remove long/short data
    for x in ${train_set} ${valid_set} ${test_sets}; do
        cp -r ${feats_dir}/org/${x} ${feats_dir}/${x}
        rm ${feats_dir}/"${x}"/wav.scp ${feats_dir}/"${x}"/segments
        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
            --audio-format wav --segments ${feats_dir}/org/${x}/segments \
            "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
        _min_length=$(python3 -c "print(int(${min_wav_duration} * 16000))")
        _max_length=$(python3 -c "print(int(${max_wav_duration} * 16000))")
        <"${feats_dir}/${x}/utt2num_samples" \
        awk '{if($2 > '$_min_length' && $2 < '$_max_length')print $0;}' \
            >"${feats_dir}/${x}/utt2num_samples_rmls"
        mv ${feats_dir}/${x}/utt2num_samples_rmls ${feats_dir}/${x}/utt2num_samples
        <"${feats_dir}/${x}/wav.scp" \
            utils/filter_scp.pl "${feats_dir}/${x}/utt2num_samples"  \
            >"${feats_dir}/${x}/wav.scp_rmls"
        mv ${feats_dir}/${x}/wav.scp_rmls ${feats_dir}/${x}/wav.scp
        <"${feats_dir}/${x}/text" \
            awk '{ if( NF != 1 ) print $0; }' >"${feats_dir}/${x}/text_rmblank"
        mv ${feats_dir}/${x}/text_rmblank ${feats_dir}/${x}/text
        local/fix_data_dir.sh "${feats_dir}/${x}"
        <"${feats_dir}/${x}/utt2spk_all_fifo" \
            utils/filter_scp.pl "${feats_dir}/${x}/text"  \
            >"${feats_dir}/${x}/utt2spk_all_fifo_rmls"
        mv "${feats_dir}/${x}/utt2spk_all_fifo_rmls" "${feats_dir}/${x}/utt2spk_all_fifo"
    done
    for x in ${test_2023}; do
        local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
            --audio-format wav --segments ${feats_dir}/org/${x}/segments \
            "${feats_dir}/org/${x}/${scp}" "${feats_dir}/${x}"
        cut -d " " -f1 ${feats_dir}/${x}/wav.scp > ${feats_dir}/${x}/uttid
        paste -d " " ${feats_dir}/${x}/uttid ${feats_dir}/${x}/uttid > ${feats_dir}/${x}/utt2spk
        cp ${feats_dir}/${x}/utt2spk ${feats_dir}/${x}/spk2utt
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Speaker profile and CMVN Generation"
    mkdir -p "profile_log"
    for x in ${train_set} ${valid_set} ${test_sets}; do   # unquoted so ${test_sets} word-splits
        # generate text_id spk2id
        python local/process_sot_fifo_textchar2spk.py --path ${feats_dir}/${x}
        echo "Successfully generated ${feats_dir}/${x}/text_id ${feats_dir}/${x}/spk2id"
        # generate text_id_train for sot
        python local/process_text_id.py ${feats_dir}/${x}
        echo "Successfully generated ${feats_dir}/${x}/text_id_train"
        # generate oracle_embedding from single-speaker audio segments
        echo "generating oracle_embedding; the log is profile_log/gen_oracle_embedding_${x}.log"
        python local/gen_oracle_embedding.py "${feats_dir}/${x}" "data/org/${x}_single_speaker" &> "profile_log/gen_oracle_embedding_${x}.log"
        echo "Successfully generated oracle embedding for ${x} (${feats_dir}/${x}/oracle_embedding.scp)"
        # generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speakers during training)
        if [ "${x}" = "${train_set}" ]; then
            python local/gen_oracle_profile_padding.py ${feats_dir}/${x}
            echo "Successfully generated oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_padding.scp)"
        else
            python local/gen_oracle_profile_nopadding.py ${feats_dir}/${x}
            echo "Successfully generated oracle profile for ${x} (${feats_dir}/${x}/oracle_profile_nopadding.scp)"
        fi
        # generate cluster_profile directly with spectral clustering (for inference, without oracle information)
        if [ "${x}" != "${train_set}" ]; then
            echo "generating cluster_profile; the log is profile_log/gen_cluster_profile_infer_${x}.log"
            python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
            echo "Successfully generated cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
        fi
        # compute CMVN
        if [ "${x}" = "${train_set}" ]; then
            local/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --fbankdir ${feats_dir}/${train_set} --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
        fi
    done
    for x in "${test_2023}"; do
        # generate cluster_profile directly with spectral clustering (for inference, without oracle information)
        python local/gen_cluster_profile_infer.py "${feats_dir}/${x}" "${feats_dir}/org/${x}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${x}.log"
        echo "Successfully generated cluster profile for ${x} (${feats_dir}/${x}/cluster_profile_infer.scp)"
    done
fi
token_list=${feats_dir}/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/${lang}_token_list/char/
    echo "make a dictionary"
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# LM Training Stage
world_size=$gpu_num  # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "Stage 4: ASR Training"
    asr_exp=${exp_dir}/${asr_model_dir}
    mkdir -p ${asr_exp}
    mkdir -p ${asr_exp}/log
    INIT_FILE=${asr_exp}/ddp_init
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $ngpu; ++i)); do
        {
            # i=0
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
            train.py \
                --task_name asr \
                --model asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --split_with_space false \
                --token_type char \
                --token_list $token_list \
                --data_dir ${feats_dir} \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --data_file_names "wav.scp,text" \
                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
                --speed_perturb ${speed_perturb} \
                --resume true \
                --output_dir ${exp_dir}/${asr_model_dir} \
                --config $asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/${asr_model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "SA-ASR training"
    asr_exp=${exp_dir}/${asr_model_dir}
    sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
    mkdir -p ${sa_asr_exp}
    mkdir -p ${sa_asr_exp}/log
    INIT_FILE=${sa_asr_exp}/ddp_init
    if [ ! -L ${feats_dir}/${train_set}/profile.scp ]; then
        ln -sr ${feats_dir}/${train_set}/oracle_profile_padding.scp ${feats_dir}/${train_set}/profile.scp
        ln -sr ${feats_dir}/${valid_set}/oracle_profile_nopadding.scp ${feats_dir}/${valid_set}/profile.scp
    fi
    if [ ! -f "${exp_dir}/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth" ]; then
        # download xvector extractor model file
        python local/download_xvector_model.py ${exp_dir}
        echo "Successfully download the pretrained xvector extractor to exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth"
    fi
    if [ -f $INIT_FILE ];then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    for ((i = 0; i < $ngpu; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
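            # NOTE: the --init_param flags below seed the SA-ASR model from the
            # stage-4 ASR checkpoint (encoder -> asr_encoder, ctc -> ctc, the
            # decoder embedding/output layers, decoder layer 0 split across
            # decoder1/decoder3, layers 1-5 -> decoder4.0-4) and from the
            # ModelScope x-vector model (encoder -> spk_encoder).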
            train.py \
                --task_name asr \
                --model sa_asr \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --split_with_space false \
                --unused_parameters true \
                --token_type char \
                --resume true \
                --token_list $token_list \
                --data_dir ${feats_dir} \
                --train_set ${train_set} \
                --valid_set ${valid_set} \
                --data_file_names "wav.scp,text,profile.scp,text_id_train" \
                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
                --speed_perturb ${speed_perturb} \
                --init_param "${asr_exp}/valid.acc.ave.pb:encoder:asr_encoder"   \
                --init_param "${asr_exp}/valid.acc.ave.pb:ctc:ctc"   \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.embed:decoder.embed" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.output_layer:decoder.asr_output_layer" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.self_attn:decoder.decoder1.self_attn" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.src_attn:decoder.decoder3.src_attn" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.0.feed_forward:decoder.decoder3.feed_forward" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.1:decoder.decoder4.0" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.2:decoder.decoder4.1" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.3:decoder.decoder4.2" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.4:decoder.decoder4.3" \
                --init_param "${asr_exp}/valid.acc.ave.pb:decoder.decoders.5:decoder.decoder4.4" \
                --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:encoder:spk_encoder"   \
                --init_param "exp/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth:decoder:spk_encoder:decoder.output_dense"   \
                --output_dir ${exp_dir}/${sa_asr_model_dir} \
                --config $sa_asr_config \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/${sa_asr_model_dir}/log/train.log.$i 2>&1
        } &
    done
    wait
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "stage 6: Inference test sets"
    for x in ${test_sets}; do
        for profile_mode in ${profile_modes}; do
            echo "decoding ${x} with ${profile_mode} profile"
            sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
            inference_tag="$(basename "${inference_config}" .yaml)"
            _dir="${sa_asr_exp}/${inference_tag}_${profile_mode}/${inference_sa_asr_model}/${x}"
            _logdir="${_dir}/logdir"
            if [ -d ${_dir} ]; then
                echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
                exit 0
            fi
            mkdir -p "${_logdir}"
            _data="${feats_dir}/${x}"
            key_file=${_data}/${scp}
            num_scp_file="$(<${key_file} wc -l)"
            _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
            split_scps=
            for n in $(seq "${_nj}"); do
                split_scps+=" ${_logdir}/keys.${n}.scp"
            done
            # shellcheck disable=SC2086
            utils/split_scp.pl "${key_file}" ${split_scps}
            _opts=
            if [ -n "${inference_config}" ]; then
                _opts+="--config ${inference_config} "
            fi
            if [ $profile_mode = "oracle" ]; then
                profile_scp=${profile_mode}_profile_nopadding.scp
            else
                profile_scp=${profile_mode}_profile_infer.scp
            fi
            ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                python -m funasr.bin.asr_inference_launch \
                    --batch_size 1 \
                    --mc True \
                    --ngpu "${_ngpu}" \
                    --njob ${njob} \
                    --nbest 1 \
                    --gpuid_list ${gpuid_list} \
                    --allow_variable_data_keys true \
                    --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
                    --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                    --data_path_and_name_and_type "${_data}/$profile_scp,profile,npy" \
                    --key_file "${_logdir}"/keys.JOB.scp \
                    --asr_train_config "${sa_asr_exp}"/config.yaml \
                    --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
                    --output_dir "${_logdir}"/output.JOB \
                    --mode sa_asr \
                    ${_opts}
            for f in token token_int score text text_id; do
                if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                    for i in $(seq "${_nj}"); do
                        cat "${_logdir}/output.${i}/1best_recog/${f}"
                    done | sort -k1 >"${_dir}/${f}"
                fi
            done
            sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
            sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
            python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
            python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
            cat ${_dir}/text.cer.txt
            python local/process_text_spk_merge.py ${_dir}
            python local/process_text_spk_merge.py ${_data}
            python local/compute_cpcer.py ${_data}/text_spk_merge ${_dir}/text_spk_merge ${_dir}/text.cpcer
            tail -n 1 ${_dir}/text.cpcer > ${_dir}/text.cpcer.txt
            cat ${_dir}/text.cpcer.txt
        done
    done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
    echo "stage 7: Inference test 2023"
    for x in ${test_2023}; do
        sa_asr_exp=${exp_dir}/${sa_asr_model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${sa_asr_exp}/${inference_tag}_cluster/${inference_sa_asr_model}/${x}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${x}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --mc True \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --nbest 1 \
                --gpuid_list ${gpuid_list} \
                --allow_variable_data_keys true \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --data_path_and_name_and_type "${_data}/cluster_profile_infer.scp,profile,npy" \
                --cmvn_file ${feats_dir}/${train_set}/cmvn/cmvn.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${sa_asr_exp}"/config.yaml \
                --asr_model_file "${sa_asr_exp}"/"${inference_sa_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode sa_asr \
                ${_opts}
        for f in token token_int score text text_id; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python local/process_text_spk_merge.py ${_dir}
    done
fi
egs/alimeeting/sa_asr/utils
egs/alimeeting/sa_asr_deprecated/README.md
egs/alimeeting/sa_asr_deprecated/asr_local.sh
egs/alimeeting/sa_asr_deprecated/asr_local_m2met_2023_infer.sh
egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
copy from egs/alimeeting/sa-asr/conf/decode_asr_rnn.yaml copy to egs/alimeeting/sa_asr_deprecated/conf/decode_asr_rnn.yaml
egs/alimeeting/sa_asr_deprecated/conf/train_asr_conformer.yaml
egs/alimeeting/sa_asr_deprecated/conf/train_sa_asr_conformer.yaml
egs/alimeeting/sa_asr_deprecated/local
New file
@@ -0,0 +1 @@
../sa_asr/local/
egs/alimeeting/sa_asr_deprecated/path.sh
egs/alimeeting/sa_asr_deprecated/run.sh
egs/alimeeting/sa_asr_deprecated/run_m2met_2023_infer.sh
egs/alimeeting/sa_asr_deprecated/utils
copy from egs/alimeeting/sa-asr/utils copy to egs/alimeeting/sa_asr_deprecated/utils
funasr/bin/asr_infer.py
@@ -1651,15 +1651,17 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.sa_asr import ASRTask
        from funasr.tasks.asr import ASRTaskSAASR
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            from funasr.tasks.sa_asr import frontend_choices
            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
funasr/bin/train.py
@@ -299,7 +299,7 @@
        "--freeze_param",
        type=str,
        default=[],
        nargs="*",
        action="append",
        help="Freeze parameters",
    )
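The switch from `nargs="*"` to `action="append"` changes how repeated flags accumulate; a quick standalone illustration (not FunASR code):
```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--freeze_param", type=str, default=[], action="append")
args = p.parse_args(["--freeze_param", "encoder", "--freeze_param", "ctc"])
print(args.freeze_param)  # ['encoder', 'ctc'] -- each repeated flag appends one pattern
```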
funasr/build_utils/build_asr_model.py
@@ -21,8 +21,10 @@
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
@@ -30,6 +32,7 @@
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -90,6 +93,8 @@
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
        sa_asr=SAASRModel,
    ),
    default="asr",
)
@@ -106,6 +111,27 @@
        chunk_conformer=ConformerChunkEncoder,
    ),
    default="rnn",
)
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
encoder_choices2 = ClassChoices(
    "encoder2",
@@ -131,6 +157,7 @@
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
    ),
    default="rnn",
)
@@ -222,6 +249,10 @@
    rnnt_decoder_choices,
    # --joint_network and --joint_network_conf
    joint_network_choices,
+    # --asr_encoder and --asr_encoder_conf
+    asr_encoder_choices,
+    # --spk_encoder and --spk_encoder_conf
+    spk_encoder_choices,
]
@@ -239,7 +270,7 @@
    # frontend
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend':
+        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
        else:
            frontend = frontend_class(**args.frontend_conf)
@@ -413,6 +444,33 @@
            joint_network=joint_network,
            **args.model_conf,
        )
    elif args.model == "sa_asr":
        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=asr_encoder.output_size(),
            **args.decoder_conf,
        )
        ctc = CTC(
            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
        )
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            asr_encoder=asr_encoder,
            spk_encoder=spk_encoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
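The new sa_asr branch feeds one feature stream into two encoders and hangs the decoder and CTC off the ASR branch only. A toy sketch of that wiring with stand-in modules (classes, sizes, and shapes here are illustrative, not FunASR's real modules):

import torch
import torch.nn as nn

# One feature stream feeds both an ASR encoder and a speaker encoder;
# the decoder and CTC consume only the ASR encoder's output size.
class ToyEncoder(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.proj = nn.Linear(input_size, output_size)
        self._output_size = output_size
    def output_size(self):
        return self._output_size
    def forward(self, feats):
        return self.proj(feats)

input_size, vocab_size = 80, 4000
asr_encoder = ToyEncoder(input_size, 256)
spk_encoder = ToyEncoder(input_size, 256)
ctc_head = nn.Linear(asr_encoder.output_size(), vocab_size)  # CTC reads the ASR branch only

feats = torch.randn(2, 100, input_size)   # (batch, frames, fbank bins)
enc_out, spk_out = asr_encoder(feats), spk_encoder(feats)
print(enc_out.shape, spk_out.shape, ctc_head(enc_out).shape)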
funasr/models/e2e_sa_asr.py
@@ -40,7 +40,7 @@
        yield
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
@@ -51,10 +51,8 @@
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
            asr_encoder: AbsEncoder,
            spk_encoder: torch.nn.Module,
-            postencoder: Optional[AbsPostEncoder],
            decoder: AbsDecoder,
            ctc: CTC,
            spk_weight: float = 0.5,
@@ -89,8 +87,6 @@
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
        self.asr_encoder = asr_encoder
        self.spk_encoder = spk_encoder
@@ -293,10 +289,6 @@
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@
            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
        else:
            encoder_out_spk=encoder_out_spk_ori
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -337,7 +324,7 @@
        )
        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
+            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
        return encoder_out, encoder_out_lens, encoder_out_spk
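encode() now returns the speaker-encoder output as a third value, after nearest-neighbor interpolation along time so it matches the ASR encoder's frame count. A runnable check of that alignment step (the shapes are made up):

import torch
import torch.nn.functional as F

# encoder_out:         (B, T_asr, D) from the ASR encoder
# encoder_out_spk_ori: (B, T_spk, D) from the speaker encoder, usually a different T
encoder_out = torch.randn(2, 50, 256)
encoder_out_spk_ori = torch.randn(2, 13, 256)

# Same call as in the patch: interpolate over time, so swap D and T first.
encoder_out_spk = F.interpolate(
    encoder_out_spk_ori.transpose(-2, -1),  # (B, D, T_spk)
    size=encoder_out.size(1),               # target length T_asr
    mode="nearest",
).transpose(-2, -1)                         # back to (B, T_asr, D)

assert encoder_out_spk.shape == (2, 50, 256)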
funasr/models/frontend/default.py
@@ -2,7 +2,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+import logging
import humanfriendly
import numpy as np
import torch
@@ -14,6 +14,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.nets_utils import make_pad_mask
class DefaultFrontend(AbsFrontend):
@@ -137,8 +138,6 @@
        return input_stft, feats_lens
class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -147,9 +146,9 @@
    def __init__(
            self,
            fs: Union[int, str] = 16000,
-            n_fft: int = 512,
-            win_length: int = None,
-            hop_length: int = 128,
+            n_fft: int = 400,
+            frame_length: int = 25,
+            frame_shift: int = 10,
            window: Optional[str] = "hann",
            center: bool = True,
            normalized: bool = False,
@@ -160,10 +159,10 @@
            htk: bool = False,
            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
            apply_stft: bool = True,
-            frame_length: int = None,
-            frame_shift: int = None,
-            lfr_m: int = None,
-            lfr_n: int = None,
+            use_channel: int = None,
+            lfr_m: int = 1,
+            lfr_n: int = 1,
+            cmvn_file: str = None
    ):
        assert check_argument_types()
        super().__init__()
@@ -172,13 +171,14 @@
        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
-        self.hop_length = hop_length
+        self.win_length = frame_length * 16
+        self.hop_length = frame_shift * 16
        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
-                win_length=win_length,
-                hop_length=hop_length,
+                win_length=self.win_length,
+                hop_length=self.hop_length,
                center=center,
                window=window,
                normalized=normalized,
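The * 16 factor above converts frame_length/frame_shift from milliseconds to samples and therefore hard-codes a 16 kHz sampling rate; with the new defaults it reproduces n_fft=400. A quick arithmetic check:

fs = 16000                                  # the "* 16" only holds at 16 kHz
frame_length_ms, frame_shift_ms = 25, 10    # new defaults above
win_length = frame_length_ms * 16           # 400 samples, matching n_fft=400
hop_length = frame_shift_ms * 16            # 160 samples
assert win_length == frame_length_ms * fs // 1000
assert hop_length == frame_shift_ms * fs // 1000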
@@ -202,7 +202,17 @@
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "multichannelfrontend"
        self.frontend_type = "default"
        self.use_channel = use_channel
        if self.use_channel is not None:
            logging.info("use the channel %d" % (self.use_channel))
        else:
            logging.info("random select channel")
        self.cmvn_file = cmvn_file
        if self.cmvn_file is not None:
            mean, std = self._load_cmvn(self.cmvn_file)
            self.register_buffer("mean", torch.from_numpy(mean))
            self.register_buffer("std", torch.from_numpy(std))
    def output_size(self) -> int:
        return self.n_mels
@@ -215,16 +225,29 @@
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
-            if isinstance(input, ComplexTensor):
-                input_stft = input
-            else:
-                input_stft = ComplexTensor(input[..., 0], input[..., 1])
+            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)
+        # 3. [Multi channel case]: Select a channel
+        if input_stft.dim() == 4:
+            # h: (B, T, C, F) -> h: (B, T, F)
+            if self.training:
+                if self.use_channel is not None:
+                    input_stft = input_stft[:, :, self.use_channel, :]
+                else:
+                    # Select 1ch randomly
+                    ch = np.random.randint(input_stft.size(2))
+                    input_stft = input_stft[:, :, ch, :]
+            else:
+                # Use the first channel
+                input_stft = input_stft[:, :, 0, :]
        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2
@@ -233,18 +256,27 @@
        # input_power: (Batch, [Channel,] Length, Freq)
        #       -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
-        bt = input_feats.size(0)
-        if input_feats.dim() ==4:
-            channel_size = input_feats.size(2)
-            # batch * channel * T * D
-            #pdb.set_trace()
-            input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
-            # input_feats = input_feats.transpose(1,2)
-            # batch * channel
-            feats_lens = feats_lens.repeat(1,channel_size).squeeze()
-        else:
-            channel_size = 1
-        return input_feats, feats_lens, channel_size
+        # 6. Apply CMVN
+        if self.cmvn_file is not None:
+            if feats_lens is None:
+                feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+            self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+            self.std = self.std.to(input_feats.device, input_feats.dtype)
+            mask = make_pad_mask(feats_lens, input_feats, 1)
+            if input_feats.requires_grad:
+                input_feats = input_feats + self.mean
+            else:
+                input_feats += self.mean
+            if input_feats.requires_grad:
+                input_feats = input_feats.masked_fill(mask, 0.0)
+            else:
+                input_feats.masked_fill_(mask, 0.0)
+            input_feats *= self.std
+        return input_feats, feats_lens
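The CMVN block added above applies the stats as an additive shift followed by a multiplicative rescale, which matches the Kaldi-nnet convention where <AddShift> stores the negated mean and <Rescale> the inverse standard deviation (an assumption consistent with _load_cmvn below). The same arithmetic on toy numpy values:

import numpy as np

feats = np.array([[1.0, 2.0], [3.0, 4.0]])   # (frames, dims)
mean = feats.mean(axis=0)
std = feats.std(axis=0)

shift = -mean          # what <AddShift> holds in the cmvn file
scale = 1.0 / std      # what <Rescale> holds

normalized = (feats + shift) * scale         # mirrors "+= mean" then "*= std" above
assert np.allclose(normalized.mean(axis=0), 0.0)
assert np.allclose(normalized.std(axis=0), 1.0)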
    def _compute_stft(
            self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -258,4 +290,27 @@
        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
-        return input_stft, feats_lens
\ No newline at end of file
+        return input_stft, feats_lens
+    def _load_cmvn(self, cmvn_file):
+        with open(cmvn_file, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+        means_list = []
+        vars_list = []
+        for i in range(len(lines)):
+            line_item = lines[i].split()
+            if line_item[0] == '<AddShift>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    add_shift_line = line_item[3:(len(line_item) - 1)]
+                    means_list = list(add_shift_line)
+                    continue
+            elif line_item[0] == '<Rescale>':
+                line_item = lines[i + 1].split()
+                if line_item[0] == '<LearnRateCoef>':
+                    rescale_line = line_item[3:(len(line_item) - 1)]
+                    vars_list = list(rescale_line)
+                    continue
+        means = np.array(means_list).astype(np.float64)  # np.float was removed in NumPy 1.24; float64 keeps the old semantics
+        vars = np.array(vars_list).astype(np.float64)
+        return means, vars
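For reference, _load_cmvn expects a Kaldi-nnet-style text file laid out as below; the token slice [3:len-1] drops the leading "<LearnRateCoef> 0 [" and the trailing "]". A self-contained parse of invented 4-dimensional stats (file layout assumed, values made up):

import numpy as np

cmvn_text = """<Nnet>
<AddShift> 4 4
<LearnRateCoef> 0 [ -8.1 -8.2 -8.3 -8.4 ]
<Rescale> 4 4
<LearnRateCoef> 0 [ 0.25 0.24 0.23 0.22 ]
</Nnet>
"""

means, scales = [], []
lines = cmvn_text.splitlines()
for i, line in enumerate(lines):
    item = line.split()
    if item and item[0] == "<AddShift>":
        tok = lines[i + 1].split()
        means = tok[3:len(tok) - 1]       # strip '<LearnRateCoef> 0 [' and ']'
    elif item and item[0] == "<Rescale>":
        tok = lines[i + 1].split()
        scales = tok[3:len(tok) - 1]

print(np.array(means, dtype=np.float64), np.array(scales, dtype=np.float64))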
funasr/tasks/asr.py
@@ -38,6 +38,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
@@ -45,6 +46,7 @@
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -54,6 +56,7 @@
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -134,6 +137,7 @@
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
+        sa_asr=SAASRModel,
    ),
    type_check=FunASRModel,
    default="asr",
@@ -175,6 +179,27 @@
    type_check=AbsEncoder,
    default="rnn",
)
+asr_encoder_choices = ClassChoices(
+    "asr_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="rnn",
+)
+spk_encoder_choices = ClassChoices(
+    "spk_encoder",
+    classes=dict(
+        resnet34_diar=ResNet34Diar,
+    ),
+    default="resnet34_diar",
+)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
@@ -197,6 +222,7 @@
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
+        sa_decoder=SAAsrTransformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
@@ -328,6 +354,12 @@
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
+        group.add_argument(
+            "--max_spk_num",
+            type=int_or_none,
+            default=None,
+            help="Maximum number of speakers handled by the SA-ASR model",
+        )
        group.add_argument(
            "--seg_dict_file",
@@ -1495,3 +1527,123 @@
        #assert check_return_type(model)
        return model
+class ASRTaskSAASR(ASRTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        # --asr_encoder and --asr_encoder_conf
+        asr_encoder_choices,
+        # --spk_encoder and --spk_encoder_conf
+        spk_encoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend' or args.frontend == "multichannelfrontend":
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+        # 5. Encoder
+        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
+        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
+        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
+        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
+        # 7. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=asr_encoder.output_size(),
+            **args.decoder_conf,
+        )
+        # 8. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
+        )
+        # import ipdb;ipdb.set_trace()
+        # 9. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            asr_encoder=asr_encoder,
+            spk_encoder=spk_encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
+        # 10. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+        assert check_return_type(model)
+        return model
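ASRTaskSAASR.build_model reads everything from an argparse namespace. A compact illustration of the fields it touches, per the code above; the choice names come from the ClassChoices tables registered in this file, and every *_conf dict is a placeholder, not a validated FunASR config:

import argparse

args = argparse.Namespace(
    token_list=["<blank>", "<s>", "</s>", "你", "好"],
    input_size=None,                 # None -> build the frontend inside the model
    frontend="multichannelfrontend",
    frontend_conf={},                # passed through to the frontend class
    cmvn_file="am.mvn",              # forwarded because of the frontend check above
    specaug=None, specaug_conf={},
    normalize=None, normalize_conf={},
    asr_encoder="conformer", asr_encoder_conf={},
    spk_encoder="resnet34_diar", spk_encoder_conf={},
    decoder="sa_decoder", decoder_conf={},
    ctc_conf={},
    model="sa_asr", model_conf={},
    init=None,
)
# model = ASRTaskSAASR.build_model(args)  # would wire frontend -> encoders -> decoder/CTC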
funasr/tasks/sa_asr.py
@@ -39,7 +39,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.e2e_sa_asr import ESPnetASRModel
+from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
@@ -120,7 +120,7 @@
model_choices = ClassChoices(
    "model",
    classes=dict(
-        asr=ESPnetASRModel,
+        asr=SAASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_bert=ParaformerBert,
@@ -620,4 +620,4 @@
            initialize(model, args.init)
        assert check_return_type(model)
-        return model
\ No newline at end of file
+        return model