yhliang
2023-08-10 08ee9e6aacc2e306211d393f6e8ce3a7f3620102
Add modular SA-ASR recipe for M2MeT2.0 (#831)

* add modular saasr

* update readme

* Delete train_paraformer.yaml

* update setup.py

* update setup.py

* update setup.py
1个文件已修改
33个文件已添加
2512 ■■■■■ 已修改文件
egs/alimeeting/modular_sa_asr/README.md 103 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/conf/mfcc_hires.conf 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/cmvn_opts 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/final.raw 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/frame_subsampling_factor 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda.mat 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda_stats 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/post_output.vec 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/srand 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/figure/20230809161919.jpg 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/compute_cpcer.py 103 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/convert_label_to_rttm.py 132 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/extract_profile_from_segments.py 104 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/finetune.py 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/format_wav_scp.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/format_wav_scp.sh 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/infer_sond.py 27 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py 83 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/merge_spk_text.py 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/remove_overlap.py 140 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/resegment_data.py 78 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/rttm2segments.py 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/run_gss.py 139 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/run_wpe.py 153 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/segment_to_lab.py 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/segmentation/detect_speech_activity.sh 225 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/test_gss.py 141 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/local/textgrid_processor.py 316 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/path.sh 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/run_asr.sh 152 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/run_diar.sh 233 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/run_enh.sh 114 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/modular_sa_asr/README.md
New file
@@ -0,0 +1,103 @@
# Get Started
This is an official modular SA-ASR system used in the M2MeT 2.0 challenge. We developed this system based on various pre-trained models after the challenge and reached ***SOTA*** (as of 2023.8.9) performance on the AliMeeting *Test_2023* set. You can also transcribe your own dataset by preparing it in the specific format shown in [Infer with your own dataset](#infer-with-your-own-dataset).
# Dependency
To run this recipe, you should install [Kaldi](https://github.com/kaldi-asr/kaldi) and set the `KALDI_ROOT` in `path.sh`.
```shell
export KALDI_ROOT=/your_kaldi_path
```
We use the [VBx](https://github.com/BUTSpeechFIT/VBx) to provide the initial diarization result to SOND and [dscore](https://github.com/nryant/dscore.git) to compute the DER. You should clone them before running this recipe.
```shell
$ mkdir VBx && cd VBx
$ git init
$ git remote add origin https://github.com/BUTSpeechFIT/VBx.git
$ git config core.sparsecheckout true
$ echo "VBx/*" >> .git/info/sparse-checkout
$ git pull origin master
$ mv VBx/* .
$ cd ..
$ git clone https://github.com/nryant/dscore.git
```
We use the [pb_chime5](https://github.com/fgnt/pb_chime5) to perform GSS. So you should install the dependencies of this repo using the following command.
```shell
$ git clone https://github.com/fgnt/pb_chime5.git
$ cd pb_chime5
$ git submodule init
$ git submodule update
$ pip install -e pb_bss/
$ pip install -e .
```
# Infer on the AliMeeting Test_2023 set
We follow the workflow shown below.
<div align="left"><img src="figure/20230809161919.jpg" width="500"/>
First you should set the `DATA_SOURCE` in `path.sh` to the data path. Your data path should be organized as follows:
```shell
Test_2023_Ali_far_release
|—— audio_dir/
|    |—— R1014_M1710.wav
|    |—— R1014_M1750.wav
|    |—— ...
|—— textgrid_dir/
|    |—— R1014_M1710.textgrid
|    |—— R1014_M1750.textgrid
|    |—— ...
|—— wav.scp
|—— segments
```
Then you can do speaker diarization with following command.
```shell
$ bash run_diar.sh
```
After diarization, you can check the result at the last line of `data/Test_2023_Ali_far_sond/dia_outputs/dia_result`. You should get a DER about 1.51%.
When you get the similar diarization result with us, then you can do the WPE and GSS using the following command.
```shell
$ bash run_enh.sh 8
```
The number 8 should be replaced with the channel number of your dataset. Here we use the AliMeeting corpus which has 8 channels.
Finally, you can decode the processed audio with the pre-trained ASR model directly using the following commands.
```shell
$ bash run_asr.sh --stage 0 --stop-stage 1
$ bash run_asr.sh --stage 3 --stop-stage 3
```
The ASR result is saved at `./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/decode_Test_2023_Ali_far_wpegss/text_cpcer`.
# Infer on the AliMeeting Test_2023 set after finetune
You can finetune the pre-trained ASR model with the AliMeeting train set to obtain a further reduction of the cpCER. To infer on the AliMeeting Test 2023 set after finetuning, you can run these commands once the train set has been processed with WPE and GSS as mentioned above.
 ```shell
$ bash run_asr.sh --stage 2 --stop-stage 3
```
# Infer with your own dataset
We also support inference with your own dataset. Your dataset should be organized as above. The `wav.scp` and `segments` files should be formatted as:
 ```shell
# wav.scp
sessionA wav_path/wav_name_A.wav
sessionB wav_path/wav_name_B.wav
sessionC wav_path/wav_name_C.wav
...
# segments
sessionA-start_time-end_time sessionA start_time end_time
sessionB-start_time-end_time sessionB start_time end_time
sessionC-start_time-end_time sessionC start_time end_time
...
```
Then you should set the `DATA_SOURCE` and `DATA_NAME` in `path.sh`. The rest of the process is the same as [Infer on the AliMeeting Test_2023 set](#infer-on-the-alimeeting-test_2023-set).
# Result
|                |VBx DER(%)     | SOND DER(%)|cp-CER(%)  |
|:---------------|:------------:|:------------:|----------:|
|before finetune  |  16.87       |  1.51     |  10.18    |
|after finetune |     16.87    |   1.51   |          |
egs/alimeeting/modular_sa_asr/conf/mfcc_hires.conf
New file
@@ -0,0 +1,11 @@
# config for high-resolution MFCC features, intended for neural network training.
# Note: we keep all cepstra, so it has the same info as filterbank features,
# but MFCC is more easily compressible (because less correlated) which is why
# we prefer this method.
--use-energy=false   # use average of log energy, not energy.
--sample-frequency=16000
--num-mel-bins=40
--num-ceps=40
--low-freq=40
--high-freq=-400
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/cmvn_opts
New file
@@ -0,0 +1 @@
--norm-means=false --norm-vars=false
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/final.raw
Binary files differ
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/frame_subsampling_factor
New file
@@ -0,0 +1 @@
3
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda.mat
Binary files differ
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/lda_stats
Binary files differ
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/post_output.vec
New file
@@ -0,0 +1 @@
 [ 30 2 1 ]
egs/alimeeting/modular_sa_asr/exp/segmentation_1a/tdnn_stats_sad_1a/srand
New file
@@ -0,0 +1 @@
0
egs/alimeeting/modular_sa_asr/figure/20230809161919.jpg
egs/alimeeting/modular_sa_asr/local/compute_cpcer.py
New file
@@ -0,0 +1,103 @@
import editdistance
import sys
import os
from itertools import permutations
def load_transcripts(file_path):
    """Load a transcript file into a list of (meeting_id, transcript) tuples.

    Each line is expected to be "<meeting_id> <transcript>". The transcript is
    split off with maxsplit=1 so transcripts that themselves contain spaces do
    not raise ValueError on tuple unpacking.
    """
    trans_list = []
    for one_line in open(file_path, "rt"):
        # Split only on the first space: the remainder is the full transcript.
        meeting_id, trans = one_line.strip().split(" ", 1)
        trans_list.append((meeting_id.strip(), trans.strip()))
    return trans_list
def calc_spk_trans(trans):
    """Split a "$"-joined multi-speaker transcript into (index, text) pairs.

    Returns [("0", text0), ("1", text1), ...] with each text stripped of
    surrounding whitespace.
    """
    stripped_parts = (piece.strip() for piece in trans.split("$"))
    return [(str(idx), text) for idx, text in enumerate(stripped_parts)]
def calc_cer(ref_trans, hyp_trans):
    """Compute the best speaker-permutation character error statistics.

    Pads the shorter side with empty speakers, then searches all speaker
    permutations for the one with the lowest CER, memoizing pairwise edit
    distances and pruning permutations that already exceed the best error
    found so far.

    Returns (best_edit_distance, best_ref_char_count, ref_spk_num, hyp_spk_num).
    NOTE(review): assumes at least one permutation survives pruning and that
    the reference character count is non-zero; an all-empty reference would
    make err/cnt divide by zero — confirm inputs are non-empty.
    """
    ref_spk_trans = calc_spk_trans(ref_trans)
    hyp_spk_trans = calc_spk_trans(hyp_trans)
    ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
    num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
    # Pad both sides to the same speaker count with empty transcripts.
    ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
    hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
    errors, counts, permutes = [], [], []
    # min_error == 0 means "no complete permutation evaluated yet"; pruning
    # via `> min_error > 0` is therefore disabled until the first full pass.
    min_error = 0
    cost_dict = {}  # memoized edit distances keyed by "refIdx-hypIdx"
    for perm in permutations(range(num_spk)):
        flag = True
        p_err, p_count = 0, 0
        for idx, p in enumerate(perm):
            # Length difference is a lower bound on edit distance: cheap prune.
            if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
                flag = False
                break
            cost_key = "{}-{}".format(idx, p)
            if cost_key in cost_dict:
                _e = cost_dict[cost_key]
            else:
                _e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
                cost_dict[cost_key] = _e
            # A single pair already worse than the best total: abandon perm.
            if _e > min_error > 0:
                flag = False
                break
            p_err += _e
            p_count += len(ref_spk_trans[idx][1])
        if flag:
            if p_err < min_error or min_error == 0:
                min_error = p_err
            errors.append(p_err)
            counts.append(p_count)
            permutes.append(perm)
    # Rank surviving permutations by CER (err / ref char count).
    sd_cer = [(err, cnt, err/cnt, permute)
              for err, cnt, permute in zip(errors, counts, permutes)]
    best_rst = min(sd_cer, key=lambda x: x[2])
    return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
def main():
    """Compute cpCER between reference and hypothesis transcript files.

    argv[1] is the reference file, argv[2] the hypothesis file. Per-meeting
    results and the overall CP-CER are written next to the hypothesis file as
    "text_cpcer". Per-speaker-count CERs (2/3/4 speakers) are printed.
    """
    ref = sys.argv[1]
    hyp = sys.argv[2]
    result_path = "/".join(hyp.split("/")[:-1]) + "/text_cpcer"
    ref_list = load_transcripts(ref)
    hyp_list = load_transcripts(hyp)
    result_file = open(result_path, 'w')
    # Accumulated [edit_distance, ref_char_count] per reference speaker count.
    record_2_spk = [0, 0]
    record_3_spk = [0, 0]
    record_4_spk = [0, 0]
    error, count = 0, 0
    for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
        assert ref_id == hyp_id
        mid = ref_id
        dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
        error, count = error + dist, count + length
        # Guard an empty reference so one bad meeting cannot crash the run.
        meeting_cer = dist / length * 100.0 if length > 0 else 0.0
        result_file.write("{} {:.2f} {} {}\n".format(mid, meeting_cer, ref_spk_num, hyp_spk_num))
        ref_spk = len(ref_trans.split("$"))
        if ref_spk == 2:
            record_2_spk[0] += dist
            record_2_spk[1] += length
        elif ref_spk == 3:
            record_3_spk[0] += dist
            record_3_spk[1] += length
        else:
            record_4_spk[0] += dist
            record_4_spk[1] += length
    # Per-speaker-count CERs; skip classes absent from the data instead of
    # raising ZeroDivisionError (the original crashed when e.g. no 2-speaker
    # meeting was present).
    for record in (record_2_spk, record_3_spk, record_4_spk):
        if record[1] > 0:
            print(record[0] / record[1] * 100.0)
    if count > 0:
        result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
    result_file.close()
egs/alimeeting/modular_sa_asr/local/convert_label_to_rttm.py
New file
@@ -0,0 +1,132 @@
import os
from funasr.utils.job_runner import MultiProcessRunnerV3
import numpy as np
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import median_filter
class MyRunner(MultiProcessRunnerV3):
    # Multi-process driver that groups chunk-level diarization labels by
    # meeting and writes the merged RTTM output.
    def prepare(self, parser):
        """Parse CLI args and build the per-meeting task list.

        Positional args: label_txt (chunked label scp), map_scp (per-meeting
        sample-mapping scp), out_rttm (output path). Returns
        (task_list, None, args) as expected by the runner base class.
        NOTE(review): process() reads args.sr and args.no_pbar, which are not
        added here — presumably registered by MultiProcessRunnerV3's own
        parser; confirm against funasr.utils.job_runner.
        """
        parser.add_argument("label_txt", type=str)
        parser.add_argument("map_scp", type=str)
        parser.add_argument("out_rttm", type=str)
        parser.add_argument("--n_spk", type=int, default=4)
        parser.add_argument("--chunk_len", type=int, default=1600)
        parser.add_argument("--shift_len", type=int, default=400)
        parser.add_argument("--ignore_len", type=int, default=5)
        parser.add_argument("--smooth_size", type=int, default=7)
        parser.add_argument("--vote_prob", type=float, default=0.5)
        args = parser.parse_args()
        # Ensure the output directory exists before workers start writing.
        if not os.path.exists(os.path.dirname(args.out_rttm)):
            os.makedirs(os.path.dirname(args.out_rttm))
        utt2labels = load_scp_as_list(args.label_txt, 'list')
        utt2labels = sorted(utt2labels, key=lambda x: x[0])
        meeting2map = load_scp_as_dict(args.map_scp)
        # Group chunk labels by meeting id (prefix before the first "-").
        meeting2labels = OrderedDict()
        for utt_id, chunk_label in utt2labels:
            mid = utt_id.split("-")[0]
            if mid not in meeting2labels:
                meeting2labels[mid] = []
            meeting2labels[mid].append(chunk_label)
        task_list = [(mid, labels, meeting2map[mid]) for mid, labels in meeting2labels.items()]
        return task_list, None, args
    def post(self, result_list, args):
        """Concatenate each worker's RTTM lines into the single output file."""
        with open(args.out_rttm, "wt") as fd:
            for results in result_list:
                fd.writelines(results)
def int2vec(x, vec_dim=8, dtype=int):
    """Convert integer ``x`` into a binary multi-hot vector of length ``vec_dim``.

    Bit order is little-endian: index i holds bit i (lowest bit first).
    The default dtype is the builtin ``int``: ``np.int`` was removed in
    NumPy 1.24 (it was merely an alias for the builtin), so the original
    default crashed at import time on modern NumPy.
    """
    b = ('{:0' + str(vec_dim) + 'b}').format(x)
    # little-endian order: lower bit first
    return (np.array(list(b)[::-1]) == '1').astype(dtype)
def seq2arr(seq, vec_dim=8):
    """Stack per-frame integer labels into an (n_frames, vec_dim) 0/1 matrix.

    ``np.row_stack`` was removed in NumPy 2.0; ``np.vstack`` is its drop-in
    equivalent.
    """
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
def sample2ms(sample, sr=16000):
    """Convert a sample index at rate ``sr`` to a 10 ms frame count.

    Despite the name, the returned unit is centiseconds (100 frames/second),
    truncated toward zero.
    """
    seconds = float(sample) / sr
    return int(seconds * 100)
def calc_multi_labels(chunk_label_list, chunk_len, shift_len, n_spk, vote_prob=0.5):
    """Merge overlapping chunk-level label sequences into frame-level labels.

    Chunks start every ``shift_len`` frames; overlapping votes are averaged
    and thresholded at ``vote_prob`` to yield a (n_frame, n_spk) 0/1 matrix.
    NOTE(review): the n_frame formula uses (n_chunk - 2), which assumes at
    least two chunks — confirm single-chunk meetings cannot occur here.
    """
    n_chunk = len(chunk_label_list)
    last_chunk_valid_frame = len(chunk_label_list[-1]) - (chunk_len - shift_len)
    n_frame = (n_chunk - 2) * shift_len + chunk_len + last_chunk_valid_frame
    multi_labels = np.zeros((n_frame, n_spk), dtype=float)
    weight = np.zeros((n_frame, 1), dtype=float)
    for i in range(n_chunk):
        raw_label = chunk_label_list[i]
        # Replace '<unk>' tokens by carrying the previous frame's label
        # forward (or '0' = silence at the very first frame).
        for k in range(len(raw_label)):
            if raw_label[k] == '<unk>':
                raw_label[k] = raw_label[k-1] if k > 0 else '0'
        chunk_multi_label = seq2arr(raw_label, n_spk)
        # NOTE: the chunk_len parameter is deliberately shadowed here with the
        # actual length of this chunk (the final chunk may be shorter).
        chunk_len = chunk_multi_label.shape[0]
        multi_labels[i*shift_len:i*shift_len+chunk_len, :] += chunk_multi_label
        weight[i*shift_len:i*shift_len+chunk_len, :] += 1
    multi_labels = multi_labels / weight  # normalizing vote
    multi_labels = (multi_labels > vote_prob).astype(int)  # voting results
    return multi_labels
def calc_spk_turns(label_arr, spk_list):
    """Extract speaker turns from a frame-level (T, n_spk) 0/1 activity matrix.

    Returns a list of [speaker_name, start_frame, duration] entries; columns
    whose speaker name is "None" are skipped entirely.
    """
    turns = []
    n_frames = label_arr.shape[0]
    for col in range(label_arr.shape[1]):
        if spk_list[col] == "None":
            continue
        active = False
        seg_start = 0
        for t in range(n_frames):
            value = label_arr[t, col]
            if value == 1 and not active:
                # Rising edge: a new turn begins at this frame.
                seg_start, active = t, True
            if value == 0 and active:
                # Falling edge: close the turn that started at seg_start.
                turns.append([spk_list[col], seg_start, t - seg_start])
                active = False
        if active:
            # Turn still open at the end of the recording.
            turns.append([spk_list[col], seg_start, n_frames - seg_start])
    return turns
def smooth_multi_labels(multi_label, win_len):
    """Median-filter each speaker column over a ``win_len``-frame window.

    Zero padding (mode="constant", cval=0.0) is applied at the boundaries,
    and the result is cast back to int so downstream code keeps 0/1 labels.
    """
    kernel = (win_len, 1)  # smooth along time only, never across speakers
    smoothed = median_filter(multi_label, kernel, mode="constant", cval=0.0)
    return smoothed.astype(int)
def process(task_args):
    """Worker function: turn one batch of meetings' chunk labels into RTTM lines.

    For each meeting: merge chunk votes, smooth them, map the cut-audio frame
    timeline back to the original recording via the meeting's map file, then
    emit RTTM SPEAKER lines for turns longer than args.ignore_len frames.
    NOTE(review): args.sr and args.no_pbar are not registered in
    MyRunner.prepare — presumably supplied by the runner base class; confirm.
    """
    _, task_list, _, args = task_args
    spk_list = ["spk{}".format(i+1) for i in range(args.n_spk)]
    template = "SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n"
    results = []
    for mid, chunk_label_list, map_file_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
        utt2map = load_scp_as_list(map_file_path, 'list')
        multi_labels = calc_multi_labels(chunk_label_list, args.chunk_len, args.shift_len, args.n_spk, args.vote_prob)
        multi_labels = smooth_multi_labels(multi_labels, args.smooth_size)
        # Original-recording length in 10 ms frames, taken from the last map entry.
        org_len = sample2ms(int(utt2map[-1][1][1]), args.sr)
        org_multi_labels = np.zeros((org_len, args.n_spk))
        # Copy each processed segment's labels back to its original position.
        for seg_id, [org_st, org_ed, st, ed] in utt2map:
            org_st, org_dur = sample2ms(int(org_st), args.sr), sample2ms(int(org_ed) - int(org_st), args.sr)
            st, dur = sample2ms(int(st), args.sr), sample2ms(int(ed) - int(st), args.sr)
            # Clip to the shorter of the two windows to avoid shape mismatch.
            ll = min(org_multi_labels[org_st: org_st+org_dur, :].shape[0], multi_labels[st: st+dur, :].shape[0])
            org_multi_labels[org_st: org_st+ll, :] = multi_labels[st: st+ll, :]
        spk_turns = calc_spk_turns(org_multi_labels, spk_list)
        spk_turns = sorted(spk_turns, key=lambda x: x[1])
        for spk, st, dur in spk_turns:
            # TODO: handle the leak of segments at the change points
            if dur > args.ignore_len:
                results.append(template.format(mid, float(st)/100, float(dur)/100, spk))
    return results
if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()
egs/alimeeting/modular_sa_asr/local/extract_profile_from_segments.py
New file
@@ -0,0 +1,104 @@
import codecs
import sys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import numpy as np
import os
import soundfile
# Extract a fixed-size (16 x 256) speaker-profile matrix per session:
# 1) average x-vector embeddings over each speaker's non-overlapped segments,
# 2) drop near-duplicate speakers (cosine similarity > 0.99),
# 3) zero-pad to 16 rows and save one .npy profile per session.
data_path = sys.argv[1]
segment_file_path = data_path + "/segments_nooverlap"
utt2spk_file_path = data_path + "/utt2spk_nooverlap"
wav_scp_path = data_path + "/wav.scp"
cluster_emb_dir = data_path + '/cluster_embedding/'
os.system("mkdir -p " + cluster_emb_dir)
cluster_profile_dir = data_path + '/cluster_profile_zeropadding16/'
os.system('mkdir -p ' + cluster_profile_dir)
# Map each utterance to its speaker, and each speaker to a list of
# (start_sample, duration_samples) segments at 16 kHz.
utt2spk = {}
spk2seg = {}
with codecs.open(utt2spk_file_path, "r", "utf-8") as f1:
    with codecs.open(segment_file_path, "r", "utf-8") as f2:
        for line in f1.readlines():
            uttid, spkid = line.strip().split(" ")
            utt2spk[uttid] = spkid
        for line in f2.readlines():
            uttid, sessionid, stime, etime = line.strip().split(" ")
            spkid = utt2spk[uttid]
            if spkid not in spk2seg.keys():
                spk2seg[spkid] = [(int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000))]
            else:
                spk2seg[spkid].append((int(float(stime) * 16000), int(float(etime) * 16000) - int(float(stime) * 16000)))
# ModelScope x-vector speaker-verification pipeline (runs on GPU).
inference_sv_pipline = pipeline(
    task=Tasks.speaker_verification,
    model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch',
    device='gpu'
)
# Compute one averaged embedding per speaker and record it in
# cluster_embedding.scp.
wav_dict = {}
with codecs.open(wav_scp_path, "r", "utf-8") as fi:
    with codecs.open(data_path + "/cluster_embedding.scp", "w", "utf-8") as fo:
        for line in fi.readlines():
            sessionid, wav_path = line.strip().split()
            wav_dict[sessionid] = wav_path
        for spkid, segs in spk2seg.items():
            sessionid = spkid.split("-")[0]
            wav_path = wav_dict[sessionid]
            wav = soundfile.read(wav_path)[0]
            if wav.ndim == 2:
                wav = wav[:, 0]  # multi-channel audio: keep channel 0 only
            all_seg_embedding_list=[]
            for seg in segs:
                # Skip segments starting within the last 0.5 s of the file.
                if seg[0] < wav.shape[0] - 0.5 * 16000:
                    # NOTE(review): seg[1] is a duration, so comparing it to the
                    # total wav length looks wrong — presumably meant
                    # seg[0] + seg[1] > wav.shape[0]; numpy slicing would
                    # truncate past-the-end reads anyway. TODO confirm.
                    if seg[1] > wav.shape[0]:
                        cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: ])["spk_embedding"]
                    else:
                        cur_seg_embedding = inference_sv_pipline(audio_in=wav[seg[0]: seg[0] + seg[1]])["spk_embedding"]
                    all_seg_embedding_list.append(cur_seg_embedding)
            all_seg_embedding = np.vstack(all_seg_embedding_list)
            spk_embedding = np.mean(all_seg_embedding, axis=0)
            np.save(cluster_emb_dir + spkid + '.npy', spk_embedding)
            fo.write(spkid + ' ' + cluster_emb_dir + spkid + '.npy' + '\n')
# Build per-session profiles: deduplicate near-identical speaker embeddings
# and zero-pad each session to exactly 16 x 256.
session2embs = {}
with codecs.open(data_path + "/cluster_embedding.scp", "r", "utf-8") as fi:
    with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "w", "utf-8") as fo:
        for line in fi.readlines():
            spkid, emb_path = line.strip().split(" ")
            sessionid = spkid.split("-")[0]
            if sessionid not in session2embs.keys():
                session2embs[sessionid] = [emb_path]
            else:
                session2embs[sessionid].append(emb_path)
        for sessionid, embs in session2embs.items():
            emb_list = [np.load(x) for x in embs]
            tmp = []
            # Keep an embedding only if no later embedding is nearly identical
            # (cosine similarity > 0.99); the last embedding always survives.
            for i in range(len(emb_list) - 1):
                flag = True
                for j in range(i + 1, len(emb_list)):
                    cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
                    if cos_sim > 0.99:
                        flag = False
                if flag:
                    tmp.append(emb_list[i][np.newaxis, :])
            tmp.append(emb_list[-1][np.newaxis, :])
            emb_list = tmp
            # tmp = []
            # for i in range(len(emb_list)):
            #     for emb in tmp:
            #         cos_sim = emb_list[i].dot(emb_list[j]) / (np.linalg.norm(emb_list[i]) * np.linalg.norm(emb_list[j]))
            #         if cos_sim > 0.99:
            #             flag = False
            #     if flag:
            #         tmp.append(emb_list[i][np.newaxis, :])
            # emb_list = tmp
            # Zero-pad to a fixed 16-speaker profile (256-dim embeddings).
            for i in range(16 - len(emb_list)):
                emb_list.append(np.zeros((1, 256)))
            emb = np.concatenate(emb_list, axis=0)
            save_path = cluster_profile_dir + sessionid + ".npy"
            np.save(save_path, emb)
            fo.write("%s %s\n" % (sessionid, save_path))
egs/alimeeting/modular_sa_asr/local/finetune.py
New file
@@ -0,0 +1,34 @@
import os
import sys
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.msdatasets.audio.asr_dataset import ASRDataset
def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model with the configuration in ``params``.

    Expects params to carry: model, data_path, dataset_type, output_dir,
    batch_bins, max_epoch and lr. Checkpoints go to params.output_dir.
    """
    os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
    trainer_args = {
        'model': params.model,
        'data_dir': ds_dict,
        'dataset_type': params.dataset_type,
        'work_dir': params.output_dir,
        'batch_bins': params.batch_bins,
        'max_epoch': params.max_epoch,
        'lr': params.lr,
    }
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=trainer_args)
    trainer.train()
if __name__ == '__main__':
    from funasr.utils.modelscope_param import modelscope_args
    # Default fine-tuning configuration for the Paraformer-large ASR model.
    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    params.output_dir = "./checkpoint"                      # directory where checkpoints are saved
    params.data_path = "./data"      # data path: a dataset uploaded to ModelScope, or a local directory
    params.dataset_type = "small"                           # use "small"; switch to "large" for more than 1000 hours of data
    params.batch_bins = 2000                                # batch size: fbank frames if dataset_type="small", milliseconds if "large"
    params.max_epoch = 20                                   # maximum number of training epochs
    params.lr = 0.00005                                     # learning rate
    modelscope_finetune(params)
egs/alimeeting/modular_sa_asr/local/format_wav_scp.py
New file
@@ -0,0 +1 @@
../../sa_asr/local/format_wav_scp.py
egs/alimeeting/modular_sa_asr/local/format_wav_scp.sh
New file
@@ -0,0 +1 @@
../../sa_asr/local/format_wav_scp.sh
egs/alimeeting/modular_sa_asr/local/infer_sond.py
New file
@@ -0,0 +1,27 @@
from funasr.bin.diar_inference_launch import inference_launch
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES']='7'
def main():
    """Run SOND diarization inference over wav + speaker-profile scp inputs.

    Optional argv: [config_yaml] [model_file] [input_dir] [output_dir];
    each falls back to a sensible default when omitted.
    """
    argv = sys.argv
    diar_config_path = argv[1] if len(argv) > 1 else "sond_fbank.yaml"
    diar_model_path = argv[2] if len(argv) > 2 else "sond.pb"
    input_dir = argv[3] if len(argv) > 3 else "./inputs"
    output_dir = argv[4] if len(argv) > 4 else "./outputs"
    # (path, name, type) triples understood by the FunASR inference launcher.
    inputs = [
        (input_dir + "/wav.scp", "speech", "sound"),
        (input_dir + "/profile.scp", "profile", "npy"),
    ]
    diar_pipeline = inference_launch(
        mode="sond",
        diar_train_config=diar_config_path,
        diar_model_file=diar_model_path,
        output_dir=output_dir,
        num_workers=16,
        ngpu=1,
    )
    diar_pipeline(inputs)
if __name__ == '__main__':
    main()
egs/alimeeting/modular_sa_asr/local/make_textgrid_rttm.py
New file
@@ -0,0 +1,58 @@
import argparse
import tqdm
import codecs
import textgrid
import pdb
class Segment(object):
    """A single labelled utterance interval from a TextGrid tier.

    Start/end times are stored rounded to two decimals (centiseconds).
    """

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid = uttid
        self.spkr = spkr
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)
        self.text = text

    def change_stime(self, time):
        """Overwrite the start time (no rounding applied)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (no rounding applied)."""
        self.etime = time
def main(args):
    """Convert one TextGrid annotation file into an RTTM file.

    Each non-empty interval in each tier becomes a SPEAKER line; tiers are
    assigned numeric speaker ids (1, 2, ...) in order of first appearance,
    and lines are emitted sorted by start time.
    """
    tg = textgrid.TextGrid.fromFile(args.input_textgrid_file)
    segments = []
    spk = {}
    num_spk = 1
    uttid = args.uttid
    for tier in tg:
        for interval in tier:
            if not interval.mark:
                continue  # skip unlabeled (silence) intervals
            if tier.name not in spk:
                spk[tier.name] = num_spk
                num_spk += 1
            segments.append(
                Segment(uttid, spk[tier.name], interval.minTime,
                        interval.maxTime, interval.mark.strip())
            )
    segments = sorted(segments, key=lambda x: x.stime)
    fmt = "SPEAKER {:s} 1 {:.2f} {:.2f} <NA> <NA> {:s} <NA> <NA>"
    # Context manager ensures the RTTM file is closed on every path
    # (the original never closed the handle).
    with codecs.open(args.output_rttm_file, "w", "utf-8") as rttm_file:
        for seg in segments:
            rttm_file.write(
                fmt.format(seg.uttid, float(seg.stime),
                           float(seg.etime) - float(seg.stime),
                           str(seg.spkr)) + "\n")
if __name__ == '__main__':
    # CLI entry point: convert a single TextGrid annotation file to RTTM.
    parser = argparse.ArgumentParser(description="Make rttm for true label")
    parser.add_argument("--input_textgrid_file", required=True, help="The textgrid file")
    parser.add_argument("--output_rttm_file", required=True, help="The output rttm file")
    parser.add_argument("--uttid", required=True, help="The utt id of the file")
    args = parser.parse_args()
    main(args)
egs/alimeeting/modular_sa_asr/local/meeting_speaker_number_process.py
New file
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
"""
Process the textgrid files
"""
import argparse
import codecs
from distutils.util import strtobool
from pathlib import Path
import textgrid
import pdb
class Segment(object):
    """A labelled utterance interval (duplicate of the helper in
    make_textgrid_rttm.py); times are rounded to two decimals on creation."""

    def __init__(self, uttid, spkr, stime, etime, text):
        self.uttid = uttid
        self.spkr = spkr
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)
        self.text = text

    def change_stime(self, time):
        """Overwrite the start time (no rounding applied)."""
        self.stime = time

    def change_etime(self, time):
        """Overwrite the end time (no rounding applied)."""
        self.etime = time
def get_args():
    """Parse command-line arguments for the textgrid post-processing script.

    All three path arguments are required: the textgrid directory, the label
    RTTM directory, and the predicted RTTM directory.
    """
    parser = argparse.ArgumentParser(description="process the textgrid files")
    for flag, description in (
        ("--path", "textgrid path"),
        ("--label_path", "label rttm file path"),
        ("--predict_path", "predict rttm file path"),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    return parser.parse_args()
def main(args):
    """Split label/predict RTTM file lists by per-meeting speaker count.

    Reads ``uttid_textgrid.flist`` (tab-separated "uttid<TAB>textgrid_path"),
    counts the tiers (= speakers) in each TextGrid, and writes
    ``speaker{2,3,4}_id`` list files into both the label and predict dirs,
    each line being the path of the corresponding per-utterance .rttm file.
    """
    # Bucket utterance ids by speaker count (meetings have 2-4 speakers;
    # other counts are ignored, as in the original).
    uttids_by_count = {2: [], 3: [], 4: []}
    with codecs.open(Path(args.path) / "uttid_textgrid.flist", "r", "utf-8") as flist:
        for line in flist:
            uttid, textgrid_file = line.strip().split("\t")
            tg = textgrid.TextGrid()
            tg.read(textgrid_file)
            num_speaker = len(tg)  # one tier per speaker
            if num_speaker in uttids_by_count:
                uttids_by_count[num_speaker].append(uttid)
    # One "<dir>/speakerN_id" list per speaker count, for both directories.
    # Context managers replace the original's six manually closed handles,
    # which leaked on any exception.
    for n, uttids in uttids_by_count.items():
        for out_dir in (args.label_path, args.predict_path):
            with codecs.open(Path(out_dir) / "speaker{}_id".format(n), "w", "utf-8") as fd:
                for uttid in uttids:
                    fd.write("%s\n" % (out_dir + "/" + uttid + ".rttm"))
if __name__ == "__main__":
    args = get_args()
    main(args)
egs/alimeeting/modular_sa_asr/local/merge_spk_text.py
New file
@@ -0,0 +1,52 @@
import sys
import codecs
import zhconv
# Merge per-utterance ASR hypotheses into one "$"-separated multi-speaker
# transcript per meeting, written as "text_merge" next to the decode result.
# argv[1]: decoded text file ("uttid tokens..."), argv[2]: utt2spk mapping.
decode_result = sys.argv[1]
utt2spk_file = sys.argv[2]
merged_result = "/".join(decode_result.split("/")[:-1]) + "/text_merge"
utt2text = {}
utt2spk = {}
spk2texts = {}
spk2text = {}
meeting2text = {}
with codecs.open(decode_result, "r", "utf-8") as f1:
    with codecs.open(utt2spk_file, "r", "utf-8") as f2:
        for line in f1.readlines():
            try:
                line_list = line.strip().split()
                uttid = line_list[0]
                text = "".join(line_list[1:])
            except:
                # Best-effort: skip malformed/empty decode lines.
                continue
            utt2text[uttid] = text
        for line in f2.readlines():
            uttid, spkid = line.strip().split()
            utt2spk[uttid] = spkid
# Group texts by speaker, keeping each utterance's start time for ordering.
for utt, text in utt2text.items():
    spk = utt2spk[utt]
    # NOTE(review): assumes uttid format "...-<start>-<end>" with a numeric
    # start time in the second-to-last field — confirm against the segmenter.
    stime = int(utt.split("-")[-2])
    if spk in spk2texts.keys():
        spk2texts[spk].append([stime, text])
    else:
        spk2texts[spk] = [[stime, text]]
# Concatenate each speaker's utterances in chronological order.
for spk, texts in spk2texts.items():
    texts = sorted(texts, key=lambda x: x[0])
    text = "".join([x[1] for x in texts])
    spk2text[spk] = text
# Join all speakers of a meeting with "$" (the separator calc_spk_trans
# in compute_cpcer.py splits on) and write one line per meeting.
with codecs.open(merged_result, "w", "utf-8") as f:
    for spk, text in spk2text.items():
        # meeting = spk.split("-")[2]
        meeting = spk.split("-")[0]
        if meeting in meeting2text.keys():
            meeting2text[meeting] = meeting2text[meeting] + "$" + text
        else:
            meeting2text[meeting] = text
    for meeting, text in meeting2text.items():
        f.write("%s %s\n" % (meeting, text))
egs/alimeeting/modular_sa_asr/local/remove_overlap.py
New file
@@ -0,0 +1,140 @@
import sys
import pdb
import codecs
import os
# Command-line layout:
#   argv[1]: input segments file     argv[2]: input utt2spk file
#   argv[3]: output segments file    argv[4]: output utt2spk file
#   argv[5]: minimum duration (seconds) a kept segment must exceed
input_segments_file = sys.argv[1]
input_utt2spk_file = sys.argv[2]
output_segments_file = sys.argv[3]
output_utt2spk_file = sys.argv[4]
threshold = sys.argv[5]  # kept as str; compared via float(threshold) in cut()
class Segment(object):
    """A speaker-homogeneous time span inside one meeting.

    Times are rounded to centiseconds; when no utterance id is supplied one
    is synthesized as "<base>-<spk>-<start*100>-<end*100>" (7-digit fields).
    """

    def __init__(self, baseid, spkid, meetingid, stime, etime, uttid=None):
        self.baseid = baseid
        self.spkid = spkid
        self.meetingid = meetingid
        self.stime = round(stime, 2)
        self.etime = round(etime, 2)
        self.dur = self.etime - self.stime
        if uttid is not None:
            self.uttid = uttid
        else:
            self.uttid = "%s-%s-%07d-%07d" % (
                self.baseid,
                self.spkid,
                self.stime * 100,
                self.etime * 100,
            )
def cut(cur_max_end_time, seg_list, cur_seg, next_c):
    """Recursively emit the non-overlapped portions of ``cur_seg``.

    ``seg_list`` is sorted by start time; the piece of ``cur_seg`` between the
    already-consumed frontier (``cur_max_end_time``) and the next segment's
    start survives.  Pieces longer than the global ``threshold`` are appended
    to the global ``out_segment_dict`` keyed by speaker.
    """
    global out_segment_dict
    piece_start = max(cur_max_end_time, cur_seg.stime)
    if next_c == len(seg_list):
        piece_end = cur_seg.etime
    else:
        piece_end = min(cur_seg.etime, seg_list[next_c].stime)
    if piece_start < piece_end and piece_end - piece_start > float(threshold):
        # Only save segments longer than threshold, for the accuracy of the
        # downstream speaker-verification step.
        out_segment_dict.setdefault(cur_seg.spkid, []).append(
            Segment(
                cur_seg.baseid,
                cur_seg.spkid,
                cur_seg.meetingid,
                piece_start,
                piece_end,
            )
        )
    if next_c < len(seg_list) and cur_seg.etime > seg_list[next_c].etime:
        # cur_seg extends beyond the next segment: keep cutting after it.
        cut(max(cur_max_end_time, seg_list[next_c].etime), seg_list, cur_seg, next_c + 1)
# ---- Driver: load inputs, strip overlapped regions, write Kaldi files. ----
meeting2seg = {}
utt2spk = {}
i = 0
with codecs.open(input_utt2spk_file, "r", "utf-8") as f:
    for line in f.readlines():
        utt, spk = line.strip().split()
        utt2spk[utt] = spk
# Build one Segment per input segments line, grouped by meeting id.
with codecs.open(input_segments_file, "r", "utf-8") as f:
    for line in f.readlines():
        i += 1
        uttid, meetingid, stime, etime = line.strip().split(" ")
        # utt2spk values look like "<base>-<spk>"; keep the speaker part only.
        spkid = utt2spk[uttid].split("-")[1]
        baseid = meetingid
        one_seg = Segment(baseid, spkid, meetingid, float(stime), float(etime))
        if one_seg.meetingid not in meeting2seg.keys():
            meeting2seg[one_seg.meetingid] = [one_seg]
        else:
            meeting2seg[one_seg.meetingid].append(one_seg)
out_segment_dict = {}  # spkid -> overlap-free Segments, filled by cut()
for k, v in meeting2seg.items():
    # Sort by start time so cut() only needs to look forward in the list.
    meeting2seg[k] = sorted(v, key=lambda x: x.stime)
    cur_max_end_time = 0
    # NOTE: `v` is the pre-sort list; its length equals len(meeting2seg[k]),
    # and the loop body indexes the sorted list, so this is safe.
    for i in range(len(v)):
        cur_seg = meeting2seg[k][i]
        cut(cur_max_end_time, meeting2seg[k], cur_seg, i + 1)
        cur_max_end_time = max(cur_max_end_time, cur_seg.etime)
out_segment_list = []
for k, v in out_segment_dict.items():
    out_segment_list.extend(out_segment_dict[k])
# Emit Kaldi-style segments and utt2spk files for the kept regions.
with codecs.open(output_segments_file, "w", "utf-8") as f_seg:
    with codecs.open(output_utt2spk_file, "w", "utf-8") as f_utt2spk:
        for out_seg in out_segment_list:
            f_seg.write(
                "%s %s %.2f %.2f\n"
                % (
                    out_seg.uttid,
                    out_seg.meetingid,
                    out_seg.stime,
                    out_seg.etime,
                )
            )
            f_utt2spk.write(
                "%s %s-%s\n"
                % (
                    out_seg.uttid,
                    out_seg.baseid,
                    out_seg.spkid,
                )
            )
egs/alimeeting/modular_sa_asr/local/resegment_data.py
New file
@@ -0,0 +1,78 @@
import soundfile
import os
import sys
import codecs
import numpy as np
import pdb
# Concatenate the speech regions of each session into one waveform, then cut
# it into 16 s windows hopped by 4 s, writing wav.scp / profile.scp / map.scp.
#   argv[1]: segments file   argv[2]: wav.scp   argv[3]: output data dir
segment_file_path = sys.argv[1]
wav_scp_file_path = sys.argv[2]
data_path = sys.argv[3]
wav_save_path = data_path + "/wav/"
os.system("mkdir -p " + wav_save_path)
pos_path = data_path + "/pos_map/"
os.system("mkdir -p " + pos_path)
wav_dict = {}
seg2time = {}
seg2time_new = {}
session2profile = {}
with codecs.open(wav_scp_file_path, "r", "utf-8") as f:
    for line in f.readlines():
        sessionid, wav_path = line.strip().split()
        wav_dict[sessionid] = wav_path
# Collect per-session (start, end) sample ranges; 16 kHz is assumed throughout.
with codecs.open(segment_file_path, "r", "utf-8") as f:
    for line in f.readlines():
        _, sessionid, stime, etime = line.strip().split()
        if sessionid not in seg2time.keys():
            seg2time[sessionid] = [(int(16000 * float(stime)), int(16000 * float(etime)))]
        else:
            seg2time[sessionid].append((int(16000 * float(stime)), int(16000 * float(etime))))
# For each segment record "uttid old_start old_end new_start new_end" (samples,
# with the uttid fields in 10 ms frames) so results can be mapped back later.
with codecs.open(data_path + "/map.scp", "w", "utf-8") as f1:
    for sessionid, seg_times in seg2time.items():
        seg2time_new[sessionid] = []
        last_time = 0
        with codecs.open(pos_path + sessionid + ".pos", "w", "utf-8") as f2:
            for seg_time in seg_times:
                tmp = seg_time[0] - last_time
                cur_seg = (seg_time[0] - tmp, seg_time[1] - tmp)
                seg2time_new[sessionid].append((seg_time[0] - last_time, seg_time[1] - last_time))
                last_time = cur_seg[1]
                f2.write("%s-%07d-%07d %d %d %d %d\n" % (sessionid, seg_time[0]/160, seg_time[1]/160, seg_time[0], seg_time[1], cur_seg[0], cur_seg[1]))
        f1.write("%s %s\n" % (sessionid, pos_path + sessionid + ".pos"))
with codecs.open(data_path + "/cluster_profile_zeropadding16.scp", "r", "utf-8") as f:
    for line in f.readlines():
        session, path = line.strip().split()
        session2profile[session] = path
# Slide a 256000-sample (16 s) window with a 64000-sample (4 s) hop over the
# concatenated speech and write each chunk plus its speaker-profile entry.
with codecs.open(data_path + "/wav.scp", "w", "utf-8") as f1:
    with codecs.open(data_path + "/profile.scp", "w", "utf-8") as f2:
        for sessionid, wav_path in wav_dict.items():
            wav = soundfile.read(wav_path)[0]
            # Multi-channel input: keep the first channel only.
            if wav.ndim == 2:
                    wav = wav[:, 0]
            seg_list = [wav[seg[0]: seg[1]] for seg in seg2time[sessionid]]
            wav_new = np.concatenate(seg_list, axis=0)
            cur_time = 0
            flag = True
            while flag:
                start = cur_time
                end = cur_time + 256000
                if end < wav_new.shape[0]:
                    cur_wav = wav_new[start: end]
                else:
                    # Last (possibly short) chunk; note the name still uses the
                    # nominal window end, not the clipped one.
                    cur_wav = wav_new[start: ]
                    flag = False
                cur_time = cur_time + 64000
                wav_name = "%s-%07d_%07d.wav" % (sessionid, start/160, end/160)
                soundfile.write(wav_save_path + wav_name, cur_wav, 16000)
                f1.write("%s %s\n" % (wav_name, wav_save_path + wav_name))
                f2.write("%s %s\n" % (wav_name, session2profile[sessionid]))
egs/alimeeting/modular_sa_asr/local/rttm2segments.py
New file
@@ -0,0 +1,29 @@
import codecs
import sys
# Convert an RTTM file into Kaldi-style segments/utt2spk files.
#   argv[1]: rttm file  argv[2]: output dir
#   argv[3]: mode - 0 (diarization: recording id in segments column 2)
#                   1 (asr: speaker id in segments column 2)
# Segments shorter than 0.3 s are dropped.
rttm_file_path = sys.argv[1]
segment_file_path = sys.argv[2]
mode = sys.argv[3]  # 0 for diarization, 1 for asr
meeting2spk = {}
with codecs.open(rttm_file_path, "r", "utf-8") as fi:
    with codecs.open(segment_file_path + "/segments", "w", "utf-8") as seg_out:
        with codecs.open(segment_file_path + "/utt2spk", "w", "utf-8") as spk_out:
            for line in fi.readlines():
                _, sessionid, _, stime, dur, _, _, spkid, _, _ = line.strip().split(" ")
                if float(dur) < 0.3:
                    continue
                start = float(stime)
                end = start + float(dur)
                # Utterance id encodes start/end in 10 ms frames.
                uttid = "%s-%07d-%07d" % (sessionid, int(start * 100), int(start * 100 + float(dur) * 100))
                spkid = "%s-%s" % (sessionid, spkid)
                if int(mode) == 0:
                    seg_out.write("%s %s %.2f %.2f\n" % (uttid, sessionid, start, end))
                    spk_out.write("%s %s\n" % (uttid, spkid))
                elif int(mode) == 1:
                    seg_out.write("%s %s %.2f %.2f\n" % (uttid, spkid, start, end))
                    spk_out.write("%s %s\n" % (uttid, spkid))
                else:
                    exit("mode only support 0 or 1!")
egs/alimeeting/modular_sa_asr/local/run_gss.py
New file
@@ -0,0 +1,139 @@
#!/usr/bin/env python
# -- coding: UTF-8
import argparse
import codecs
import glob
import logging
import os
from nara_wpe.utils import stft, istft
import numpy as np
import scipy.io.wavfile as wf
from tqdm import tqdm
from test_gss import *
def get_parser():
    """Build the CLI parser for GSS-based enhancement.

    Returns:
        argparse.ArgumentParser with required --wav-scp, --segments and
        --output-dir string options.
    """
    parser = argparse.ArgumentParser("Doing GSS based enhancement.")
    parser.add_argument(
        "--wav-scp",
        type=str,
        required=True,
        help="Wav scp file for enhancement.",
    )
    parser.add_argument(
        "--segments",
        type=str,
        required=True,
        # Fixed copy-pasted help text: this option takes the segments file
        # (uttid spkid start end), not a wav scp.
        help="Segments file for enhancement.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory of GSS enhanced data.",
    )
    return parser
def wfread(f):
    """Read a wav file and return (samples, sample_rate).

    int16 PCM is rescaled to float32 in [-1, 1); other dtypes pass through.
    """
    sample_rate, samples = wf.read(f)
    if samples.dtype == np.int16:
        samples = np.float32(samples) / 32768
    return samples, sample_rate
def wfwrite(z, fs, store_path):
    """Write float samples in [-1, 1) to ``store_path`` as 16-bit PCM."""
    pcm = np.int16(z * 32768)
    wf.write(store_path, fs, pcm)
def main():
    """Run GSS mask estimation + MVDR beamforming over a wav scp.

    Each wav-scp line is "<sessionid> <ch1.wav> <ch2.wav> ...".  For every
    session, speaker activity intervals from --segments drive a CACGMM-based
    GSS mask estimator; one enhanced wav per speaker is written under
    --output-dir with a "-<spk>" suffix.
    """
    args = get_parser().parse_args()
    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    stft_window, stft_shift = 512, 256
    gss = GSS(iterations=20, iterations_post=1)
    bf = Beamformer("mvdrSouden_ban", "mask_mul")
    with codecs.open(args.wav_scp, "r") as handle:
        lines_content = handle.readlines()
    # Strip trailing newlines without touching other whitespace.
    wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
    cnt = 0
    # sessionid -> spkid -> list of (start_sec, end_sec) activity intervals.
    session2spk2dur = {}
    with codecs.open(args.segments, "r") as handle:
        for line in handle.readlines():
            uttid, spkid, stime, etime = line.strip().split(" ")
            # spkid is "<session>-<spk>"; the session is its first field.
            sessionid = spkid.split("-")[0]
            if sessionid not in session2spk2dur.keys():
                session2spk2dur[sessionid] = {}
            if spkid not in session2spk2dur[sessionid].keys():
                session2spk2dur[sessionid][spkid] = []
            session2spk2dur[sessionid][spkid].append((float(stime), float(etime)))
    # import pdb;pdb.set_trace()
    for wav_idx in tqdm(range(len(wav_lines)), leave=True, desc="0"):
        # get wav files from scp file
        file_list = wav_lines[wav_idx].split(" ")
        sessionid, wav_list = file_list[0], file_list[1:]
        signal_list = []
        time_activity = []
        cnt += 1
        logging.info("Processing {}: {}".format(cnt, wav_list[0]))
        # read all wavs
        for wav in wav_list:
            data, fs = wfread(wav)
            signal_list.append(data)
        try:
            obstft = np.stack(signal_list, axis=0)
        except:
            # Channels of unequal length: truncate to the shortest.
            minlen = min([len(s) for s in signal_list])
            obstft = np.stack([s[:minlen] for s in signal_list])
        wavlen = obstft.shape[1]
        obstft = stft(obstft, stft_window, stft_shift)
        # get activated timestamps and frequencies
        speaker_list = []
        for spk, dur in session2spk2dur[sessionid].items():
            speaker_list.append(spk.split("-")[-1])
            time_activity.append(get_time_activity(dur, wavlen, fs))
        # Extra always-active row: a garbage/noise class for the mask estimator.
        time_activity.append([True] * wavlen)
        frequency_activity = get_frequency_activity(
            time_activity, stft_window, stft_shift
        )
        # import pdb;pdb.set_trace()
        # generate mask
        masks = gss(obstft, frequency_activity)
        masks_bak = masks
        # Beamform once per speaker (the last mask row is the noise class).
        for i in range(masks.shape[0] - 1):
            target_mask = masks[i]
            distortion_mask = np.sum(np.delete(masks, i, axis=0), axis=0)
            xhat = bf(obstft, target_mask=target_mask, distortion_mask=distortion_mask)
            xhat = istft(xhat, stft_window, stft_shift)
            audio_dir = "/".join(wav_list[0].split("/")[:-1])
            store_path = (
                wav_list[0]
                .replace(audio_dir, args.output_dir)
                .replace(".wav", "-{}.wav".format(speaker_list[i]))
            )
            if not os.path.exists(os.path.split(store_path)[0]):
                os.makedirs(os.path.split(store_path)[0], exist_ok=True)
            logging.info("Save wav file {}.".format(store_path))
            wfwrite(xhat, fs, store_path)
            # NOTE(review): restores masks from the backup each iteration;
            # nothing above mutates `masks`, so this looks like a no-op —
            # confirm before removing.
            masks = masks_bak
# Script entry point.
if __name__ == "__main__":
    main()
egs/alimeeting/modular_sa_asr/local/run_wpe.py
New file
@@ -0,0 +1,153 @@
#!/usr/bin/env python
# _*_ coding: UTF-8 _*_
import argparse
import codecs
import os
import logging
from multiprocessing import Pool
import numpy as np
import scipy.io.wavfile as wf
from nara_wpe.utils import istft, stft
from nara_wpe.wpe import wpe_v8 as wpe
def wpe_worker(
    wav_scp,
    audio_dir="",
    output_dir="",
    channel=0,
    processing_id=None,
    processing_num=None,
):
    """Dereverberate multi-channel wavs with WPE and mirror them to output_dir.

    Each wav_scp line lists the channel files of one recording separated by
    spaces.  Output paths are formed by replacing ``audio_dir`` with
    ``output_dir`` in each input path.  When ``processing_id``/``processing_num``
    are given, only lines with index % processing_num == processing_id are
    handled (round-robin sharding across worker processes).

    NOTE(review): ``channel`` is used as ``data[:, channel - 1]``, so the
    default of 0 selects the LAST channel (index -1) — confirm callers always
    pass a 1-based channel.
    """
    sampling_rate = 16000
    iterations = 5
    stft_options = dict(
        size=512,
        shift=128,
        window_length=None,
        fading=True,
        pad=True,
        symmetric_window=False,
    )
    with codecs.open(wav_scp, "r") as handle:
        lines_content = handle.readlines()
    wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
    for wav_idx in range(len(wav_lines)):
        if processing_id is None:
            processing_token = True
        else:
            if wav_idx % processing_num == processing_id:
                processing_token = True
            else:
                processing_token = False
        if processing_token:
            wav_list = wav_lines[wav_idx].split(" ")
            # Skip recordings whose outputs already all exist (resumability).
            file_exist = True
            for wav_path in wav_list:
                file_exist = file_exist and os.path.exists(
                    wav_path.replace(audio_dir, output_dir)
                )
                if not file_exist:
                    break
            if not file_exist:
                logging.info("wait to process {} : {}".format(wav_idx, wav_list[0]))
                signal_list = []
                for f in wav_list:
                    _, data = wf.read(f)
                    data = data[:, channel - 1]
                    if data.dtype == np.int16:
                        data = np.float32(data) / 32768
                    signal_list.append(data)
                # Channels may differ slightly in length; truncate to shortest.
                min_len = len(signal_list[0])
                max_len = len(signal_list[0])
                for i in range(1, len(signal_list)):
                    min_len = min(min_len, len(signal_list[i]))
                    max_len = max(max_len, len(signal_list[i]))
                if min_len != max_len:
                    for i in range(len(signal_list)):
                        signal_list[i] = signal_list[i][:min_len]
                y = np.stack(signal_list, axis=0)
                # STFT -> (freq, channel, frame) layout expected by wpe.
                Y = stft(y, **stft_options).transpose(2, 0, 1)
                Z = wpe(Y, iterations=iterations, statistics_mode="full").transpose(
                    1, 2, 0
                )
                z = istft(Z, size=stft_options["size"], shift=stft_options["shift"])
                for d in range(len(signal_list)):
                    store_path = wav_list[d].replace(audio_dir, output_dir)
                    if not os.path.exists(os.path.split(store_path)[0]):
                        os.makedirs(os.path.split(store_path)[0], exist_ok=True)
                    tmpwav = np.int16(z[d, :] * 32768)
                    wf.write(store_path, sampling_rate, tmpwav)
            else:
                logging.info("file exist {} : {}".format(wav_idx, wav_list[0]))
    return None
def wpe_manager(
    wav_scp, processing_num=1, audio_dir="", output_dir="", channel=1
):
    """Run WPE dereverberation over ``wav_scp``, optionally in parallel.

    Args:
        wav_scp: scp file with one space-separated list of channel wavs per line.
        processing_num: number of worker processes (>1 uses a Pool).
        audio_dir: input path prefix replaced by ``output_dir`` in outputs.
        output_dir: directory for the enhanced wavs.
        channel: 1-based channel index forwarded to ``wpe_worker``.
    """
    if processing_num > 1:
        pool = Pool(processes=processing_num)
        for i in range(processing_num):
            pool.apply_async(
                wpe_worker,
                kwds={
                    "wav_scp": wav_scp,
                    "processing_id": i,
                    "processing_num": processing_num,
                    "audio_dir": audio_dir,
                    "output_dir": output_dir,
                    # Bug fix: channel was previously not forwarded, so parallel
                    # workers silently fell back to wpe_worker's default
                    # (channel=0 -> data[:, -1], i.e. the last channel).
                    "channel": channel,
                },
            )
        pool.close()
        pool.join()
    else:
        wpe_worker(wav_scp, audio_dir=audio_dir, output_dir=output_dir, channel=channel)
    return None
if __name__ == "__main__":
    # CLI wrapper around wpe_manager.
    parser = argparse.ArgumentParser("run_wpe")
    parser.add_argument(
        "--wav-scp",
        type=str,
        required=True,
        # Fixed typo in help text ("pf" -> "of").
        help="Path of wav scp file",
    )
    parser.add_argument(
        "--audio-dir",
        type=str,
        required=True,
        help="Directory of input audio files",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory of WPE enhanced audio files",
    )
    parser.add_argument(
        "--channel",
        type=str,
        required=True,
        help="Channel number of input audio",
    )
    # Idiom fix: use an int default directly instead of the string "1"
    # (argparse re-parsed the string default through type=int).
    parser.add_argument("--nj", type=int, default=1, help="number of process")
    args = parser.parse_args()
    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info("wavfile={}".format(args.wav_scp))
    logging.info("processingnum={}".format(args.nj))
    wpe_manager(
        wav_scp=args.wav_scp,
        processing_num=args.nj,
        audio_dir=args.audio_dir,
        output_dir=args.output_dir,
        channel=int(args.channel)
    )
egs/alimeeting/modular_sa_asr/local/segment_to_lab.py
New file
@@ -0,0 +1,58 @@
import argparse
import os
def read_segments_file(segments_file):
    """Parse a Kaldi segments file.

    Returns:
        dict mapping utt_id -> list of (segment_utt_id, start, end) tuples,
        in file order.
    """
    utt2segments = dict()
    with open(segments_file, "r") as fr:
        for raw in fr.readlines():
            fields = raw.strip().split()
            seg_id, utt_id = fields[0], fields[1]
            begin, finish = float(fields[2]), float(fields[3])
            utt2segments.setdefault(utt_id, []).append((seg_id, begin, finish))
    return utt2segments
def write_label(label_file, label_list):
    """Write (start, end) pairs to ``label_file`` as "start end sp" lines."""
    with open(label_file, "w") as fw:
        fw.writelines(f"{start} {end} sp\n" for (start, end) in label_list)
        fw.flush()
def write_label_scp_file(label_scp_file, label_scp: dict):
    """Write an "utt_id label_path" scp file from the given mapping."""
    with open(label_scp_file, "w") as fw:
        fw.writelines(f"{uid} {path}\n" for uid, path in label_scp.items())
        fw.flush()
def main(args):
    """Turn a segments file into per-utterance .lab files plus a label scp."""
    input_segments = args.input_segments
    label_path = args.label_path
    output_label_scp_file = args.output_label_scp_file
    utt2segments = read_segments_file(input_segments)
    print(f"Collect {len(utt2segments)} utt2segments in file {input_segments}")
    result_label_scp = dict()
    for utt_id, segment_list in utt2segments.items():
        # One .lab file per utterance, listing its (start, end) speech spans.
        cur_label_path = os.path.join(label_path, f"{utt_id}.lab")
        write_label(cur_label_path, label_list=[(s, e) for (_, s, e) in segment_list])
        result_label_scp[utt_id] = cur_label_path
    write_label_scp_file(output_label_scp_file, result_label_scp)
    print(f"Write {len(result_label_scp)} labels")
if __name__ == '__main__':
    # CLI: convert a Kaldi segments file into per-utterance .lab files.
    parser = argparse.ArgumentParser(description="Make the lab file for segments")
    parser.add_argument("--input_segments", required=True, help="The input segments file")
    parser.add_argument("--label_path", required=True, help="The label_path to save file.lab")
    parser.add_argument("--output_label_scp_file", required=True, help="The output label.scp file")
    args = parser.parse_args()
    main(args)
egs/alimeeting/modular_sa_asr/local/segmentation/detect_speech_activity.sh
New file
@@ -0,0 +1,225 @@
#!/bin/bash
# Copyright 2016-17  Vimal Manohar
#              2017  Nagendra Kumar Goel
# Apache 2.0.
# This script does nnet3-based speech activity detection given an input
# kaldi data directory and outputs a segmented kaldi data directory.
# This script can also do music detection and other similar segmentation
# using appropriate options such as --output-name output-music.
set -e
set -o pipefail
set -u

if [ -f ./path.sh ]; then . ./path.sh; fi
#export PATH=/usr/local/cuda-10.0/bin:$PATH
#export LD_LIBRARY_PATH=/usr/local/cuda-10.0/lib64:$LD_LIBRARY_PATH
#echo $PATH
#echo $LD_LIBRARY_PATH

affix=  # Affix for the segmentation
nj=32
cmd=run.pl
stage=-1

# Feature options (Must match training)
mfcc_config=conf/mfcc_hires.conf
feat_affix=   # Affix for the type of feature used

output_name=output   # The output node in the network
sad_name=sad    # Base name for the directory storing the computed loglikes
                # Can be music for music detection
segmentation_name=segmentation  # Base name for the directory doing segmentation
                                # Can be segmentation_music for music detection

# SAD network config
iter=final  # Model iteration to use

# Contexts must ideally match training for LSTM models, but
# may not necessarily for stats components
extra_left_context=0  # Set to some large value, typically 40 for LSTM (must match training)
extra_right_context=0
extra_left_context_initial=-1
extra_right_context_final=-1
frames_per_chunk=150

# Decoding options
graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0"
acwt=0.3

# These <from>_in_<to>_weight represent the fraction of <from> probability
# to transfer to <to> class.
# e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3
transform_probs_opts=""

# Postprocessing options
segment_padding=0.2   # Duration (in seconds) of padding added to segments
min_segment_dur=0   # Minimum duration (in seconds) required for a segment to be included
                    # This is before any padding. Segments shorter than this duration will be removed.
                    # This is an alternative to --min-speech-duration above.
merge_consecutive_max_dur=0   # Merge consecutive segments as long as the merged segment is no longer than this many
                              # seconds. The segments are only merged if their boundaries are touching.
                              # This is after padding by --segment-padding seconds.
                              # 0 means do not merge. Use 'inf' to not limit the duration.

# Log the invocation; quoted to avoid word-splitting / glob expansion (SC2048).
echo "$*"

. utils/parse_options.sh
if [ $# -ne 5 ]; then
  echo "This script does nnet3-based speech activity detection given an input kaldi "
  echo "data directory and outputs an output kaldi data directory."
  echo "See script for details of the options to be supplied."
  echo "Usage: $0 <src-data-dir> <sad-nnet-dir> <mfcc-dir> <work-dir> <out-data-dir>"
  echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev exp/nnet3_sad_snr/nnet_tdnn_j_n4 \\"
  echo "    mfcc_hires exp/segmentation_sad_snr/nnet_tdnn_j_n4 data/ami_sdm1_dev"
  echo ""
  echo "Options: "
  echo "  --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  echo "  --nj <num-job>                                 # number of parallel jobs to run."
  echo "  --stage <stage>                                # stage to do partial re-run from."
  echo "  --convert-data-dir-to-whole <true|false>    # If true, the input data directory is "
  echo "                                              # first converted to whole data directory (i.e. whole recordings) "
  echo "                                              # and segmentation is done on that."
  echo "                                              # If false, then the original segments are "
  echo "                                              # retained and they are split into sub-segments."
  echo "  --output-name <name>    # The output node in the network"
  echo "  --extra-left-context  <context|0>   # Set to some large value, typically 40 for LSTM (must match training)"
  echo "  --extra-right-context  <context|0>   # For BLSTM or statistics pooling"
  exit 1
fi

src_data_dir=$1   # The input data directory that needs to be segmented.
                  # If convert_data_dir_to_whole is true, any segments in that will be ignored.
sad_nnet_dir=$2   # The SAD neural network
mfcc_dir=$3       # The directory to store the features
dir=$4            # Work directory
data_dir=$5       # The output data directory will be ${data_dir}_seg

affix=${affix:+_$affix}
feat_affix=${feat_affix:+_$feat_affix}
data_id=`basename $data_dir`
sad_dir=${dir}/${sad_name}${affix}_${data_id}${feat_affix}
seg_dir=${dir}/${segmentation_name}${affix}_${data_id}${feat_affix}
# test_data_dir=data/${data_id}${feat_affix}
# Decode directly on the source data dir (no whole-dir conversion here).
test_data_dir=${src_data_dir}

###############################################################################
## Forward pass through the network network and dump the log-likelihoods.
###############################################################################
frame_subsampling_factor=1
if [ -f $sad_nnet_dir/frame_subsampling_factor ]; then
  frame_subsampling_factor=$(cat $sad_nnet_dir/frame_subsampling_factor)
fi

mkdir -p $dir
if [ $stage -le 1 ]; then
  if [ "$(readlink -f $sad_nnet_dir)" != "$(readlink -f $dir)" ]; then
    cp $sad_nnet_dir/cmvn_opts $dir || exit 1
  fi

  ########################################################################
  ## Initialize neural network for decoding using the output $output_name
  ########################################################################
  if [ ! -z "$output_name" ] && [ "$output_name" != output ]; then
    $cmd $dir/log/get_nnet_${output_name}.log \
      nnet3-copy --edits="rename-node old-name=$output_name new-name=output" \
      $sad_nnet_dir/$iter.raw $dir/${iter}_${output_name}.raw || exit 1
    iter=${iter}_${output_name}
  else
    if ! diff $sad_nnet_dir/$iter.raw $dir/$iter.raw; then
      cp $sad_nnet_dir/$iter.raw $dir/
    fi
  fi
  echo ${test_data_dir}
  steps/nnet3/compute_output.sh --nj $nj --cmd "$cmd" \
    --iter ${iter} \
    --extra-left-context $extra_left_context \
    --extra-right-context $extra_right_context \
    --extra-left-context-initial $extra_left_context_initial \
    --extra-right-context-final $extra_right_context_final \
    --frames-per-chunk $frames_per_chunk --apply-exp true \
    --frame-subsampling-factor $frame_subsampling_factor \
    ${test_data_dir} $dir $sad_dir || exit 1
fi

###############################################################################
## Prepare FST we search to make speech/silence decisions.
###############################################################################
utils/data/get_utt2dur.sh --nj $nj --cmd "$cmd" $test_data_dir || exit 1
frame_shift=$(utils/data/get_frame_shift.sh $test_data_dir) || exit 1

graph_dir=${dir}/graph_${output_name}
if [ $stage -le 2 ]; then
  mkdir -p $graph_dir

  # 1 for silence and 2 for speech
  cat <<EOF > $graph_dir/words.txt
<eps> 0
silence 1
speech 2
EOF

  $cmd $graph_dir/log/make_graph.log \
    steps/segmentation/internal/prepare_sad_graph.py $graph_opts \
      --frame-shift=$(perl -e "print $frame_shift * $frame_subsampling_factor") - \| \
    fstcompile --isymbols=$graph_dir/words.txt --osymbols=$graph_dir/words.txt '>' \
      $graph_dir/HCLG.fst
fi

###############################################################################
## Do Viterbi decoding to create per-frame alignments.
###############################################################################
# The class priors are needed to turn posteriors into pseudo-likelihoods.
post_vec=$sad_nnet_dir/post_${output_name}.vec
if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then
  if [ ! -f $sad_nnet_dir/post_${output_name}.txt ]; then
    echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. "
    echo "Re-run the corresponding stage in the training script possibly "
    echo "with --compute-average-posteriors=true or compute the priors "
    echo "from the training labels"
    exit 1
  else
    post_vec=$sad_nnet_dir/post_${output_name}.txt
  fi
fi

mkdir -p $seg_dir
if [ $stage -le 3 ]; then
  steps/segmentation/internal/get_transform_probs_mat.py \
    --priors="$post_vec" $transform_probs_opts > $seg_dir/transform_probs.mat

  steps/segmentation/decode_sad.sh --acwt $acwt --cmd "$cmd" \
    --nj $nj \
    --transform "$seg_dir/transform_probs.mat" \
    $graph_dir $sad_dir $seg_dir
fi

###############################################################################
## Post-process segmentation to create kaldi data directory.
###############################################################################
if [ $stage -le 4 ]; then
  steps/segmentation/post_process_sad_to_segments.sh \
    --segment-padding $segment_padding --min-segment-dur $min_segment_dur \
    --merge-consecutive-max-dur $merge_consecutive_max_dur \
    --cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \
    ${test_data_dir} ${seg_dir} ${seg_dir}
fi

if [ $stage -le 5 ]; then
  utils/data/subsegment_data_dir.sh ${test_data_dir} ${seg_dir}/segments \
    ${data_dir}_seg
fi

echo "$0: Created output segmented kaldi data directory in ${data_dir}_seg"
exit 0
egs/alimeeting/modular_sa_asr/local/test_gss.py
New file
@@ -0,0 +1,141 @@
import io
import functools
import logging
# import soundfile as sf
import numpy as np
import matplotlib
import matplotlib.pylab as plt
# from IPython.display import display, Audio
from nara_wpe.utils import stft, istft
from pb_bss.distribution import CACGMMTrainer
from pb_bss.evaluation import InputMetrics, OutputMetrics
from dataclasses import dataclass
# from beamforming_wrapper import beamform_mvdr_souden_from_masks
from pb_chime5.utils.numpy_utils import segment_axis_v2
from textgrid_processor import read_textgrid_from_file
def get_time_activity(dur_list, wavlen, sr):
    """Expand (start_sec, end_sec) pairs into a per-sample boolean list.

    NOTE(review): segments whose end sample exceeds ``wavlen`` are skipped
    entirely rather than clipped — confirm that this is intended.
    """
    activity = [False] * wavlen
    for seg in dur_list:
        first = int(seg[0] * sr)
        last = int(seg[1] * sr)
        if last > wavlen:
            continue
        activity[first:last] = [True] * (last - first)
    logging.info("Num of actived samples {}".format(activity.count(True)))
    return activity
def get_frequency_activity(
    time_activity,
    stft_window_length,
    stft_shift,
    stft_fading=True,
    stft_pad=True,
):
    """Project per-sample activity onto STFT frames.

    A frame counts as active when any sample inside its analysis window is
    active.  When ``stft_fading`` is set, the same fading padding the STFT
    applies is mirrored here so frame indices line up.
    """
    samples = np.asarray(time_activity)
    if stft_fading:
        pad_spec = np.array([(0, 0)] * samples.ndim)
        pad_spec[-1, :] = stft_window_length - stft_shift  # Consider fading
        samples = np.pad(samples, pad_spec, mode="constant")
    windows = segment_axis_v2(
        samples,
        length=stft_window_length,
        shift=stft_shift,
        end="pad" if stft_pad else "cut",
    )
    return windows.any(axis=-1)
@dataclass
class Beamformer:
    """Apply a beamformer selected by ``type`` plus an optional postfilter.

    Supported types: "mvdrSouden_ban" (MVDR with blind analytic
    normalization from pb_chime5), "ch0" (pick first channel), "sum"
    (sum over channels).  Supported postfilter: None or "mask_mul".
    """

    type: str
    postfilter: str

    def __call__(self, Obs, target_mask, distortion_mask, debug=False):
        kind = self.type
        if kind == "ch0":
            enhanced = Obs[0]
        elif kind == "sum":
            enhanced = np.sum(Obs, axis=0)
        elif kind == "mvdrSouden_ban":
            # Imported lazily so the other beamformer types work without
            # pb_chime5 installed.
            from pb_chime5.speech_enhancement.beamforming_wrapper import (
                beamform_mvdr_souden_from_masks,
            )
            enhanced = beamform_mvdr_souden_from_masks(
                Y=Obs,
                X_mask=target_mask,
                N_mask=distortion_mask,
                ban=True,
            )
        else:
            raise NotImplementedError(kind)

        if self.postfilter == "mask_mul":
            enhanced = enhanced * target_mask
        elif self.postfilter is not None:
            raise NotImplementedError(self.postfilter)
        return enhanced
@dataclass
class GSS:
    """Guided Source Separation mask estimator.

    Fits a complex angular central GMM (cacGMM) independently per frequency
    bin, using the per-frame speaker activity pattern both to initialize
    the mixture weights and to constrain which sources may be active.
    """

    iterations: int = 20  # EM iterations per frequency bin
    iterations_post: int = 0  # NOTE(review): not used in __call__ below
    verbose: bool = True
    # use_pinv: bool = False
    # stable: bool = True
    def __call__(self, Obs, acitivity_freq=None, debug=False):
        # Activity pattern -> per-source mixture weights: zeros become a tiny
        # floor so no source gets exactly zero probability, then normalize
        # over the source axis.
        initialization = np.asarray(acitivity_freq, dtype=np.float64)
        initialization = np.where(initialization == 0, 1e-10, initialization)
        initialization = initialization / np.sum(initialization, keepdims=True, axis=0)
        # Broadcast over 257 frequency bins (presumably a 512-point STFT —
        # TODO confirm against the stft configuration).
        initialization = np.repeat(initialization[None, ...], 257, axis=0)
        # Hard constraint: a source may only be active where its activity is
        # nonzero.
        source_active_mask = np.asarray(acitivity_freq, dtype=bool)
        source_active_mask = np.repeat(source_active_mask[None, ...], 257, axis=0)
        cacGMM = CACGMMTrainer()
        if debug:
            learned = []  # NOTE(review): never filled — leftover debug hook
        all_affiliations = []
        F = Obs.shape[-1]
        T = Obs.T.shape[-2]
        # Fit one cacGMM per frequency bin.
        for f in range(F):
            if self.verbose:
                if f % 50 == 0:
                    logging.info(f"{f}/{F}")
            # T: Consider end of signal.
            # This should not be nessesary, but activity is for inear and not for
            # array.
            cur = cacGMM.fit(
                y=Obs.T[f, ...],
                initialization=initialization[f, ..., :T],
                iterations=self.iterations,
                source_activity_mask=source_active_mask[f, ..., :T],
            )
            affiliation = cur.predict(
                Obs.T[f, ...],
                source_activity_mask=source_active_mask[f, ..., :T],
            )
            all_affiliations.append(affiliation)
        # (F, num_spk, T) -> (num_spk, T, F)
        posterior = np.array(all_affiliations).transpose(1, 2, 0)
        return posterior
egs/alimeeting/modular_sa_asr/local/textgrid_processor.py
New file
@@ -0,0 +1,316 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import codecs
class TextGrid(object):
    """Top-level representation of a Praat TextGrid annotation file."""

    def __init__(
        self,
        file_type="",
        object_class="",
        xmin=0.0,
        xmax=0.0,
        tiers_status="",
        tiers=None,
    ):
        """Create a TextGrid.

        Args:
            file_type: "File type" header value.
            object_class: "Object class" header value.
            xmin, xmax: global time bounds in seconds.
            tiers_status: the "tiers?" header value.
            tiers: list of Tier objects; defaults to a new empty list.

        Raises:
            ValueError: if xmax < xmin.
        """
        self.file_type = file_type
        self.object_class = object_class
        self.xmin = xmin
        self.xmax = xmax
        self.tiers_status = tiers_status
        # Fix: the previous `tiers=[]` default was a shared mutable — every
        # instance created without an explicit list aliased the same object.
        self.tiers = [] if tiers is None else tiers
        if self.xmax < self.xmin:
            raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))

    def cutoff(self, xstart=None, xend=None):
        """Return a copy restricted to [xstart, xend], cutting every tier.

        The result keeps the original xmin; its xmax is shifted so the
        duration equals xend - xstart.

        Raises:
            ValueError: if xend < xstart.
        """
        if xstart is None:
            xstart = self.xmin
        if xend is None:
            xend = self.xmax
        if xend < xstart:
            raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
        new_xmax = xend - xstart + self.xmin
        new_xmin = self.xmin
        new_tiers = []
        for tier in self.tiers:
            new_tiers.append(tier.cutoff(xstart=xstart, xend=xend))
        return TextGrid(
            file_type=self.file_type,
            object_class=self.object_class,
            xmin=new_xmin,
            xmax=new_xmax,
            tiers_status=self.tiers_status,
            tiers=new_tiers,
        )
class Tier(object):
    """One TextGrid tier: a named, time-bounded list of Interval objects."""

    def __init__(self, tier_class="", name="", xmin=0.0, xmax=0.0, intervals=None):
        """Create a tier.

        Args:
            tier_class: tier class string (e.g. "IntervalTier").
            name: tier name (usually a speaker id).
            xmin, xmax: tier time bounds in seconds.
            intervals: list of Interval objects; defaults to a new empty list.

        Raises:
            ValueError: if xmax < xmin.
        """
        self.tier_class = tier_class
        self.name = name
        self.xmin = xmin
        self.xmax = xmax
        # Fix: the previous `intervals=[]` default was a shared mutable list
        # silently reused across every Tier instance.
        self.intervals = [] if intervals is None else intervals
        if self.xmax < self.xmin:
            raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))

    def cutoff(self, xstart=None, xend=None):
        """Return a new Tier restricted to [xstart, xend].

        Intervals straddling a boundary are clipped; all times in the result
        are shifted left by (xstart - xmin) so the tier keeps its origin.

        Raises:
            ValueError: if xend < xstart.
        """
        if xstart is None:
            xstart = self.xmin
        if xend is None:
            xend = self.xmax
        if xend < xstart:
            raise ValueError("xend ({}) < xstart ({})".format(xend, xstart))
        bias = xstart - self.xmin
        new_xmax = xend - bias
        new_xmin = self.xmin
        new_intervals = []
        for interval in self.intervals:
            if interval.xmax <= xstart or interval.xmin >= xend:
                # Entirely outside the window: drop it.
                pass
            elif interval.xmin < xstart:
                # Straddles the left edge: clip its start to the window.
                new_intervals.append(
                    Interval(
                        xmin=new_xmin, xmax=interval.xmax - bias, text=interval.text
                    )
                )
            elif interval.xmax > xend:
                # Straddles the right edge: clip its end to the window.
                new_intervals.append(
                    Interval(
                        xmin=interval.xmin - bias, xmax=new_xmax, text=interval.text
                    )
                )
            else:
                # Fully inside: just shift by the bias.
                new_intervals.append(
                    Interval(
                        xmin=interval.xmin - bias,
                        xmax=interval.xmax - bias,
                        text=interval.text,
                    )
                )
        return Tier(
            tier_class=self.tier_class,
            name=self.name,
            xmin=new_xmin,
            xmax=new_xmax,
            intervals=new_intervals,
        )
class Interval(object):
    """A single labelled time span [xmin, xmax] carrying its text payload."""

    def __init__(self, xmin=0.0, xmax=0.0, text=""):
        # Store bounds and label, then validate ordering.
        self.xmin = xmin
        self.xmax = xmax
        self.text = text
        if not self.xmin <= self.xmax:
            raise ValueError("xmax ({}) < xmin ({})".format(self.xmax, self.xmin))
def read_textgrid_from_file(filepath):
    """Parse a Praat TextGrid file into a TextGrid object.

    NOTE(review): the parser is hard-wired to the AliMeeting export layout:
    a fixed header order (File type / Object class / blank line / xmin /
    xmax / tiers? / size / "item []:") and CRLF line endings — the
    assertions below fail on files that use bare LF endings.
    """
    with codecs.open(filepath, "r", encoding="utf-8") as handle:
        lines = handle.readlines()
    # Drop a single trailing blank (CRLF-only) line if present.
    if lines[-1] == "\r\n":
        lines = lines[:-1]
    assert "File type" in lines[0], "error line 0, {}".format(lines[0])
    # Header values: take everything after '=', stripping spaces, quotes
    # and line endings.
    file_type = (
        lines[0]
        .split("=")[1]
        .replace(" ", "")
        .replace('"', "")
        .replace("\r", "")
        .replace("\n", "")
    )
    assert "Object class" in lines[1], "error line 1, {}".format(lines[1])
    object_class = (
        lines[1]
        .split("=")[1]
        .replace(" ", "")
        .replace('"', "")
        .replace("\r", "")
        .replace("\n", "")
    )
    assert lines[2] == "\r\n", "error line 2, {}".format(lines[2])
    assert "xmin" in lines[3], "error line 3, {}".format(lines[3])
    xmin = float(
        lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert "xmax" in lines[4], "error line 4, {}".format(lines[4])
    xmax = float(
        lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert "tiers?" in lines[5], "error line 5, {}".format(lines[5])
    tiers_status = (
        lines[5].split("?")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert "size" in lines[6], "error line 6, {}".format(lines[6])
    size = int(
        lines[6].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert lines[7] == "item []:\r\n", "error line 7, {}".format(lines[7])
    # Locate the first line of every "item [k]:" tier block; each tier runs
    # up to the start of the next one (or to end of file).
    tier_start = []
    for item_idx in range(size):
        tier_start.append(lines.index(" " * 4 + "item [{}]:\r\n".format(item_idx + 1)))
    tier_end = tier_start[1:] + [len(lines)]
    tiers = []
    for tier_idx in range(size):
        tiers.append(
            read_tier_from_lines(
                tier_lines=lines[tier_start[tier_idx] + 1 : tier_end[tier_idx]]
            )
        )
    return TextGrid(
        file_type=file_type,
        object_class=object_class,
        xmin=xmin,
        xmax=xmax,
        tiers_status=tiers_status,
        tiers=tiers,
    )
def read_tier_from_lines(tier_lines):
    """Parse one tier block (the lines following its "item [k]:" line).

    Supports both the well-formed 5-lines-per-interval layout and an
    unformatted 4-lines-per-interval variant found in some AliMeeting
    TextGrid exports.
    """
    assert "class" in tier_lines[0], "error line 0, {}".format(tier_lines[0])
    tier_class = (
        tier_lines[0]
        .split("=")[1]
        .replace(" ", "")
        .replace('"', "")
        .replace("\r", "")
        .replace("\n", "")
    )
    assert "name" in tier_lines[1], "error line 1, {}".format(tier_lines[1])
    name = (
        tier_lines[1]
        .split("=")[1]
        .replace(" ", "")
        .replace('"', "")
        .replace("\r", "")
        .replace("\n", "")
    )
    assert "xmin" in tier_lines[2], "error line 2, {}".format(tier_lines[2])
    xmin = float(
        tier_lines[2].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert "xmax" in tier_lines[3], "error line 3, {}".format(tier_lines[3])
    xmax = float(
        tier_lines[3].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    assert "intervals: size" in tier_lines[4], "error line 4, {}".format(tier_lines[4])
    intervals_num = int(
        tier_lines[4].split("=")[1].replace(" ", "").replace("\r", "").replace("\n", "")
    )
    # handle unformatted case
    # R12_S203204205_C09_I1_Near_203.TextGrid
    # R12_S203204205_C09_I1_Near_205.TextGrid
    if tier_lines[-1] == "\n":
        tier_lines = tier_lines[:-1]
    if len(tier_lines[5:]) == intervals_num * 5:
        intervals = []
        for intervals_idx in range(intervals_num):
            assert tier_lines[
                5 + 5 * intervals_idx + 0
            ] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
            # NOTE(review): this second assert checks offset +1 against the
            # SAME "intervals [k]:" text as offset +0 — it probably should
            # check a different line; confirm against a 5-line-format file.
            assert tier_lines[
                5 + 5 * intervals_idx + 1
            ] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
            intervals.append(
                read_interval_from_lines(
                    interval_lines=tier_lines[
                        7 + 5 * intervals_idx : 10 + 5 * intervals_idx
                    ]
                )
            )
    elif len(tier_lines[5:]) == intervals_num * 4:
        # handle unformatted case
        # R12_S203204205_C09_I1_Near_203.TextGrid
        # R12_S203204205_C09_I1_Near_204.TextGrid
        # R12_S203204205_C09_I1_Near_205.TextGrid
        intervals = []
        for intervals_idx in range(intervals_num):
            assert tier_lines[
                5 + 4 * intervals_idx + 0
            ] == " " * 8 + "intervals [{}]:\r\n".format(intervals_idx + 1)
            intervals.append(
                read_interval_from_lines(
                    interval_lines=tier_lines[
                        6 + 4 * intervals_idx : 9 + 4 * intervals_idx
                    ]
                )
            )
    else:
        # Fix: removed a leftover `import pdb; pdb.set_trace()` that would
        # hang any non-interactive run before this error could be raised.
        raise ValueError(
            "error lines {} % {} != 0".format(len(tier_lines[5:]), intervals_num)
        )
    return Tier(
        tier_class=tier_class, name=name, xmin=xmin, xmax=xmax, intervals=intervals
    )
def read_interval_from_lines(interval_lines):
    """Parse the three lines of one TextGrid interval (xmin / xmax / text)."""
    assert len(interval_lines) == 3, "error lines"

    def _value(line):
        # Everything after '=' with spaces and CR/LF characters removed.
        return (
            line.split("=")[1]
            .replace(" ", "")
            .replace("\r", "")
            .replace("\n", "")
        )

    assert "xmin" in interval_lines[0], "error line 0, {}".format(interval_lines[0])
    xmin = float(_value(interval_lines[0]))
    assert "xmax" in interval_lines[1], "error line 1, {}".format(interval_lines[1])
    xmax = float(_value(interval_lines[1]))
    assert "text" in interval_lines[2], "error line 2, {}".format(interval_lines[2])
    # The text field is additionally quoted; strip the quotes too.
    text = _value(interval_lines[2]).replace('"', "")
    return Interval(xmin=xmin, xmax=xmax, text=text)
egs/alimeeting/modular_sa_asr/path.sh
New file
@@ -0,0 +1,14 @@
# Environment for the modular SA-ASR recipe: FunASR, Kaldi and data locations.
export FUNASR_DIR=$PWD/../../..
# Edit these paths/names before running any stage.
export KALDI_ROOT=/Your_Kaldi_root
export DATA_SOURCE=/Your_data_path
export DATA_NAME=Test_2023_Ali_far
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PATH
[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/tools/config/common_path.sh
export LC_ALL=C
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:./utils:$FUNASR_DIR:$PATH
export PYTHONPATH=$FUNASR_DIR:$PYTHONPATH
egs/alimeeting/modular_sa_asr/run_asr.sh
New file
@@ -0,0 +1,152 @@
#!/usr/bin/env bash
# Fine-tune and decode a Paraformer ASR model for the modular SA-ASR recipe.
. ./path.sh || exit 1;
# machines configuration
CUDA_VISIBLE_DEVICES="4,5,6,7"
gpu_num=4
count=1
gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
finetune=true
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=2
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="data" #feature output dictionary
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="1.0"
stage=0
stop_stage=1
# feature configuration
feats_dim=80
nj=64
# exp tag
tag="finetune"
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=Train_Ali_far_wpegss
valid_set=Test_Ali_far_wpegss
test_sets="${DATA_NAME}_wpegss"
asr_config=conf/train_paraformer.yaml
model_dir="$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
pretrain_model_dir=./speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
inference_config=$pretrain_model_dir/decoding.yaml
token_list=$pretrain_model_dir/tokens.txt
# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
    # Standard arithmetic expansion; the previous $[...] form is deprecated.
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
# Decode either the fine-tuned checkpoint or the pretrained model.
if ${finetune}; then
    inference_asr_model=./checkpoint/valid.acc.ave.pb
    finetune_tag="_finetune"
else
    inference_asr_model=$pretrain_model_dir/model.pb
    finetune_tag=""
fi
# Stage 0: (re)create the ./utils symlink to the shared aishell utilities.
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # Drop a stale link first so the link is always freshly created.
    if [ -L ./utils ]; then
        unlink ./utils
    fi
    ln -s ../../aishell/transformer/utils
fi
# Download Model
world_size=$gpu_num  # run on one machine
# Stage 1: fetch the pretrained Paraformer model from ModelScope (skipped
# when the directory already exists).
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Download Model"
    if [ ! -d $pretrain_model_dir ]; then
        git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git
    fi
fi
# ASR Training Stage
world_size=$gpu_num  # run on one machine
# Stage 2: multi-GPU fine-tuning; all training options live in local/finetune.py.
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: ASR Training"
    python  -m torch.distributed.launch \
     --nproc_per_node $gpu_num local/finetune.py
fi
# Testing Stage
# Stage 3: split each test set's wav.scp into _nj parts, decode them in
# parallel, merge the per-job outputs and score speaker-attributed CER.
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Inference"
    for dset in ${test_sets}; do
        _dir="$pretrain_model_dir/decode_${dset}${finetune_tag}"
        _logdir="${_dir}/logdir"
        if [ -d ${_dir} ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            exit 0
        fi
        mkdir -p "${_logdir}"
        _data="./data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        # Never use more jobs than there are utterances.
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        # One funasr inference process per key split, dispatched by run.pl.
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file $pretrain_model_dir/am.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config $pretrain_model_dir/config.yaml \
                --asr_model_file $inference_asr_model \
                --output_dir "${_logdir}"/output.JOB \
                --mode paraformer \
                ${_opts}
        # Concatenate per-job recognition outputs back into single files.
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        # Attach speaker ids, then compute speaker-attributed CER (cpCER).
        python local/merge_spk_text.py ${_dir}/text ${_data}/utt2spk
        python local/compute_cpcer.py ${_data}/text_merge ${_dir}/text_merge
        echo "cpCER is saved at ${_dir}/text_cpcer"
    done
fi
egs/alimeeting/modular_sa_asr/run_diar.sh
New file
@@ -0,0 +1,233 @@
#!/usr/bin/env bash
# Speaker diarization pipeline: SAD -> x-vector/VBx -> SOND refinement.
set -e
set -o pipefail
. path.sh || exit 1
train_cmd=utils/run.pl
# data path
data_source_dir=$DATA_SOURCE
textgrid_dir=$data_source_dir/textgrid_dir/
wav_dir=$data_source_dir/audio_dir/
# work path
work_dir=./data/${DATA_NAME}_sc/
sad_dir=$work_dir/sad_part/
sad_work_dir=$sad_dir/exp/
sad_result_dir=$sad_dir/sad
dia_dir=$work_dir/dia_part/
dia_vad_dir=$dia_dir/vad/
dia_rttm_dir=$dia_dir/rttm/
dia_emb_dir=$dia_dir/embedding/
dia_rtt_label_dir=$dia_dir/label_rttm/
dia_result_dir=$dia_dir/result_DER/
sond_work_dir=./data/${DATA_NAME}_sond/
asr_work_dir=./data/${DATA_NAME}_wpegss/org/
# Create every working directory up front; mkdir -p takes them all at once.
mkdir -p \
    $work_dir $sad_dir $sad_work_dir $sad_result_dir \
    $dia_dir $dia_vad_dir $dia_rttm_dir $dia_emb_dir \
    $dia_rtt_label_dir $dia_result_dir \
    $sond_work_dir $asr_work_dir || exit 1;
stage=0
stop_stage=9
nj=4
sm_size=83
if [ $stage -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # Stage 0: link kaldi's steps/ and utils/ into the recipe directory.
    # Fix: the previous logic only *unlinked* on a second run and never
    # recreated the links, leaving steps/ and utils/ missing for later stages.
    if [ -L ./steps ]; then
        unlink ./steps
    fi
    ln -s $KALDI_ROOT/egs/wsj/s5/steps || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
    if [ -L ./utils ]; then
        unlink ./utils
    fi
    ln -s $KALDI_ROOT/egs/wsj/s5/utils || { echo "You must install kaldi first, and set the KALDI_ROOT in path.sh" && exit 1; }
fi
if [ $stage -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # Prepare the AliMeeting data
    echo "Prepare Alimeeting data"
    # Build a kaldi-style data dir (wav.scp / utt2spk / spk2utt / text) with
    # one "utterance" per meeting recording, keyed by the wav basename.
    find $wav_dir -name "*\.wav" > $work_dir/wavlist
    sort  $work_dir/wavlist > $work_dir/tmp
    cp $work_dir/tmp $work_dir/wavlist
    awk -F '/' '{print $NF}' $work_dir/wavlist | awk -F '.' '{print $1}' > $work_dir/uttid
    paste -d " " $work_dir/uttid $work_dir/wavlist > $work_dir/wav.scp
    paste -d " " $work_dir/uttid $work_dir/uttid > $work_dir/utt2spk
    cp $work_dir/utt2spk $work_dir/spk2utt
    cp $work_dir/uttid $work_dir/text
    sad_feat=$sad_dir/feat/mfcc
    cp $work_dir/wav.scp $sad_dir
    cp $work_dir/utt2spk $sad_dir
    cp $work_dir/spk2utt $sad_dir
    cp $work_dir/text    $sad_dir
    utils/fix_data_dir.sh $sad_dir
    ## first we extract the feature for sad model
    steps/make_mfcc.sh --nj $nj --cmd "$train_cmd" \
        --mfcc-config conf/mfcc_hires.conf \
        $sad_dir $sad_dir/make_mfcc $sad_feat
fi
if [ $stage -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # Do Speech Activity Detectation
    echo "Do SAD"
    ./utils/split_data.sh $sad_dir $nj
    ## do the segmentations
    # Runs the pretrained TDNN-stats SAD model shipped under exp/ to produce
    # speech segments for each meeting.
    local/segmentation/detect_speech_activity.sh --nj $nj --stage 0 \
        --cmd "$train_cmd" $sad_dir exp/segmentation_1a/tdnn_stats_sad_1a/ \
        $sad_dir/feat/mfcc $sad_work_dir $sad_result_dir
fi
if [ $stage -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "Do Speaker Embedding Extractor"
    cp $work_dir/wav.scp $dia_dir
    # Convert the SAD segments into per-recording .lab files for VBx.
    python local/segment_to_lab.py --input_segments $sad_dir/sad_seg/segments \
                                     --label_path $dia_vad_dir \
                                     --output_label_scp_file $dia_dir/label.scp ||exit 1;
    ./utils/split_data.sh $work_dir $nj
    # Extract x-vectors with the ResNet101 ONNX model, one job per split.
    ${train_cmd} JOB=1:${nj} $dia_dir/exp/extract_embedding.JOB.log \
    python VBx/predict.py --in-file-list $work_dir/split${nj}/JOB/text \
                          --in-lab-dir $dia_dir/vad \
                          --in-wav-dir $wav_dir \
                          --out-ark-fn $dia_emb_dir/embedding_out.JOB.ark \
                          --out-seg-fn $dia_emb_dir/embedding_out.JOB.seg \
                          --weights VBx/models/ResNet101_16kHz/nnet/final.onnx \
                          --backend onnx
    echo "success"
fi
if [ $stage -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # The Speaker Embedding Cluster
    echo "Do the Speaker Embedding Cluster"
    # The meeting data is long so that the cluster is a little bit slow
    # AHC initialization followed by VB-HMM resegmentation (VBx); threshold /
    # Fa / Fb / loopP are the recipe's tuned values for AliMeeting.
    ${train_cmd} JOB=1:${nj} $dia_dir/exp/cluster.JOB.log \
     python VBx/vbhmm.py --init AHC+VB \
                         --out-rttm-dir $dia_rttm_dir \
                         --xvec-ark-file $dia_emb_dir/embedding_out.JOB.ark \
                         --segments-file $dia_emb_dir/embedding_out.JOB.seg \
                         --xvec-transform VBx/models/ResNet101_16kHz/transform.h5 \
                         --plda-file VBx/models/ResNet101_16kHz/plda \
                         --threshold 0.14 \
                         --lda-dim 128 \
                         --Fa 0.3 \
                         --Fb 17 \
                         --loopP 0.99
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
    echo "Process textgrid to obtain rttm label"
    # List all TextGrid annotation files and pair them with utterance ids.
    find -L $textgrid_dir -iname "*.TextGrid" >  $work_dir/textgrid.flist
    sort  $work_dir/textgrid.flist  > $work_dir/tmp
    cp $work_dir/tmp $work_dir/textgrid.flist
    paste $work_dir/uttid $work_dir/textgrid.flist > $work_dir/uttid_textgrid.flist
    # Read the two whitespace-separated fields directly: `read -r` keeps
    # backslashes intact and avoids spawning two awk processes per line.
    while read -r text_grid text_grid_path
    do
        python local/make_textgrid_rttm.py --input_textgrid_file $text_grid_path \
                                           --uttid $text_grid \
                                           --output_rttm_file $dia_rtt_label_dir/${text_grid}.rttm
    done < $work_dir/uttid_textgrid.flist
    # Rebuild the concatenated reference so re-runs do not duplicate it.
    if [ -f "$dia_rtt_label_dir/all.rttm" ]; then
        rm -f $dia_rtt_label_dir/all.rttm
    fi
    cat $dia_rtt_label_dir/*.rttm > $dia_rtt_label_dir/all.rttm
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
    echo "Get VBx DER result"
    # Build scp lists of reference and system RTTMs *before* all.rttm is
    # (re)created below, so the concatenation does not list itself.
    # NOTE(review): on a re-run an existing all.rttm would be picked up by
    # this find — consider excluding it explicitly.
    find $dia_rtt_label_dir  -name "*.rttm" > $dia_rtt_label_dir/ref.scp
    find $dia_rttm_dir  -name "*.rttm" > $dia_rttm_dir/sys.scp
    if [ -f "$dia_rttm_dir/all.rttm" ]; then
        rm -f $dia_rttm_dir/all.rttm
    fi
    cat $dia_rttm_dir/*.rttm > $dia_rttm_dir/all.rttm
    # Score with and without a 0.25 s collar.
    collar_set="0 0.25"
    # Group meetings by speaker count so DER can be reported per subset.
    python local/meeting_speaker_number_process.py  --path=$work_dir \
        --label_path=$dia_rtt_label_dir   --predict_path=$dia_rttm_dir
    speaker_number="2 3 4"
    for weight_collar in $collar_set;
    do
        # all meeting
        python dscore/score.py --collar $weight_collar  \
            -R $dia_rtt_label_dir/ref.scp  -S $dia_rttm_dir/sys.scp > $dia_result_dir/speaker_all_DER_overlaps_${weight_collar}.log
        # 2,3,4 speaker meeting
        for speaker_count in $speaker_number;
        do
            python dscore/score.py --collar $weight_collar  \
                -R $dia_rtt_label_dir/speaker${speaker_count}_id  -S $dia_rttm_dir/speaker${speaker_count}_id > $dia_result_dir/speaker_${speaker_count}_DER_overlaps_${weight_collar}.log
        done
    done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
    echo "Downloading Pre-trained model..."
    # Fix: plain `mkdir` fails on an existing directory and, with `set -e`
    # active, aborted the whole script on every re-run.
    mkdir -p ./SOND
    cd ./SOND
    # Skip the clones when the repositories are already present (re-run safety).
    if [ ! -d speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch ]; then
        git clone https://www.modelscope.cn/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch.git
    fi
    if [ ! -d speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch ]; then
        git clone https://www.modelscope.cn/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch.git
    fi
    # -f: replace existing links instead of failing on a second run.
    ln -s -f speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.pth ./sv.pb
    cp speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/sv.yaml ./sv.yaml
    ln -s -f speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.pth ./sond.pb
    cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond_fbank.yaml ./sond_fbank.yaml
    cp speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/sond.yaml ./sond.yaml
    cd ..
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
    echo "Prepare data for sond"
    cp $work_dir/wav.scp $sond_work_dir
    # convert rttm to segments
    python local/rttm2segments.py $dia_rttm_dir/all.rttm $sond_work_dir 0
    # remove the overlapped part
    python local/remove_overlap.py $sond_work_dir/segments $sond_work_dir/utt2spk \
     $sond_work_dir/segments_nooverlap $sond_work_dir/utt2spk_nooverlap 0.3
    # extract speaker profile from the filtered segments file
    python local/extract_profile_from_segments.py $sond_work_dir
    # segment data to 16s
    python local/resegment_data.py \
        $data_source_dir/segments \
        $data_source_dir/wav.scp \
        $sond_work_dir
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
    echo "Diarization with SOND"
    # Run SOND inference, smooth/vote the frame labels into an RTTM, score
    # it against the reference, then export segments for the ASR stage.
    python local/infer_sond.py SOND/sond.yaml SOND/sond.pb $sond_work_dir $sond_work_dir/dia_outputs
    python local/convert_label_to_rttm.py \
        $sond_work_dir/dia_outputs/labels.txt \
        $sond_work_dir/map.scp \
        $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
        --ignore_len 10 --no_pbar --smooth_size ${sm_size} \
        --vote_prob 0.5 --n_spk 16
    python dscore/score.py \
        -r $dia_rtt_label_dir/all.rttm \
        -s $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm \
        --collar 0.25 &> $sond_work_dir/dia_outputs/dia_result
    # convert rttm to segments
    python local/rttm2segments.py $sond_work_dir/dia_outputs/prediction_sm_${sm_size}.rttm $asr_work_dir 1
fi
egs/alimeeting/modular_sa_asr/run_enh.sh
New file
@@ -0,0 +1,114 @@
#!/usr/bin/env bash
# Speech enhancement front-end: WPE dereverberation + GSS, then ASR data prep.
set -e
set -o pipefail
# Print a timestamped message prefixed with the caller's file:line:function.
log() {
    local caller_file=${BASH_SOURCE[1]##*/}
    printf '%s (%s:%s:%s) %b\n' \
        "$(date '+%Y-%m-%dT%H:%M:%S')" \
        "$caller_file" "${BASH_LINENO[0]}" "${FUNCNAME[1]}" "$*"
}
SECONDS=0
# general configuration
stage=1
stop_stage=3
nj=10
log "$0 $*"
. utils/parse_options.sh
. ./path.sh || exit 1
train_cmd=utils/run.pl
data_source_dir=$DATA_SOURCE
audio_dir=$data_source_dir/audio_dir
output_wpe_dir=$data_source_dir/wpe_audio_dir
output_gss_dir=$data_source_dir/gss_audio_dir
asr_data_path=./data/${DATA_NAME}_wpegss
channel=$1  # number of input channels (first positional argument)
log "Start Speech Enhancement."
# Link pb_bss out of the pb_chime5 checkout.  Fix: the guard must test the
# link being created (./pb_bss), not ./utils — the old check created or
# skipped the link based on an unrelated path.
if [ ! -L ./pb_bss ]; then
    ln -s ./pb_chime5/pb_bss
fi
# WPE
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    log "stage 1: Start WPE."
    # Dereverberate each channel independently; work is split into $nj jobs.
    for ch in `seq ${channel}`; do
        mkdir -p  ${output_wpe_dir}_${ch}/log/
        # split wav.scp
        find $audio_dir/ -name "*.wav" > ${output_wpe_dir}_${ch}/wav.scp
        arr=""
        for i in `seq ${nj}`; do
            arr="$arr ${output_wpe_dir}_${ch}/log/wav.${i}.scp"
        done
        split_scp.pl ${output_wpe_dir}_${ch}/wav.scp $arr
        # do wpe
        # Generate one small wrapper script per job so run.pl can launch it.
        for n in `seq ${nj}`; do
            cat <<-EOF >${output_wpe_dir}_${ch}/log/wpe.${n}.sh
python local/run_wpe.py \
    --wav-scp ${output_wpe_dir}_${ch}/log/wav.${n}.scp \
    --audio-dir ${audio_dir} \
    --output-dir ${output_wpe_dir}_${ch} \
    --ch $ch
EOF
        done
        chmod a+x ${output_wpe_dir}_${ch}/log/wpe.*.sh
        ${train_cmd} JOB=1:${nj} ${output_wpe_dir}_${ch}/log/wpe.JOB.log \
            ${output_wpe_dir}_${ch}/log/wpe.JOB.sh
    done
fi
# GSS
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    log "stage 2: Start GSS"
    if [ ! -d pb_chime5/ ]; then
        log "Please install pb_chime5 by local/install_pb_chime5.sh"
        exit 1
    fi
    mkdir -p $output_gss_dir/log
    # split wpe.scp
    # Build one scp line per meeting: "<uttid> <ch1.wav> ... <chN.wav>".
    for i in `seq ${channel}`; do
        find ${output_wpe_dir}_${i}/ -name "*.wav" > $output_gss_dir/tmp${i}
    done
    awk -F '/' '{print($NF)}' $output_gss_dir/tmp1 | cut -d "." -f1 > $output_gss_dir/tmp
    arr=""
    for i in `seq ${channel}`; do
        arr="$arr $output_gss_dir/tmp${i}"
    done
    paste -d " " $output_gss_dir/tmp $arr > $output_gss_dir/wpe.scp
    rm -f $output_gss_dir/tmp*
    arr=""
    for i in `seq ${nj}`; do
        arr="$arr $output_gss_dir/log/wpe.${i}.scp"
    done
    split_scp.pl $output_gss_dir/wpe.scp $arr
    # do gss
    # Generate one wrapper script per job, as in the WPE stage.
    for n in `seq ${nj}`; do
        cat <<-EOF >${output_gss_dir}/log/gss.${n}.sh
python local/run_gss.py \
    --wav-scp ${output_gss_dir}/log/wpe.${n}.scp \
    --segments $asr_data_path/org/segments \
    --output-dir ${output_gss_dir}
EOF
    done
    chmod a+x ${output_gss_dir}/log/gss.*.sh
    ${train_cmd} JOB=1:${nj} ${output_gss_dir}/log/gss.JOB.log \
        ${output_gss_dir}/log/gss.JOB.sh
fi
# Prepare data for ASR
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3: Preparing data for ASR"
    # Collect the GSS-enhanced wavs into a kaldi wav.scp keyed by filename.
    find $output_gss_dir -name "*.wav" > $asr_data_path/org/wav_list
    awk -F '/' '{print($NF)}' $asr_data_path/org/wav_list | sed 's/\.wav//g' > $asr_data_path/org/uttid
    paste -d " " $asr_data_path/org/uttid $asr_data_path/org/wav_list > $asr_data_path/org/wav.scp
    # Re-format the audio and apply the diarization segments.
    bash local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
        --audio-format wav --segments $asr_data_path/org/segments \
        "$asr_data_path/org/wav.scp" "$asr_data_path"
fi
log "End speech enhancement"
setup.py
@@ -19,8 +19,11 @@
        "soundfile>=0.12.1",
        "h5py>=2.10.0",
        "kaldiio>=2.17.0",
        "kaldi-io==0.9.8",
        "torch_complex",
        "nltk>=3.4.5",
        "onnxruntime"
        "numexpr"
        # ASR
        "sentencepiece",
        "jieba",
@@ -32,6 +35,8 @@
        "editdistance>=0.5.2",
        "tensorboard",
        "g2p",
        "nara_wpe",
        "Cython",
        # PAI
        "oss2",
        "edit-distance",
@@ -123,4 +128,4 @@
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
)