嘉渊
2023-07-19 92e8d4358a0c0ea323f00fa578382252c5b18732
update
2个文件已修改
1个文件已添加
136 ■■■■■ 已修改文件
egs/callhome/eend_ola/local/infer.py 132 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/local/random_mixture.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/local/run_prepare_shared_eda.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/local/infer.py
New file
@@ -0,0 +1,132 @@
import argparse
import os
import numpy as np
import soundfile as sf
import torch
import yaml
from scipy.signal import medfilt
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from funasr.build_utils.build_model_from_file import build_model_from_file
# Offline diarization inference script: reads the recordings listed in a
# wav.scp file, runs the diarization model on each of them and writes the
# detected speaker segments to a single RTTM file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        help="model config file",
    )
    parser.add_argument(
        "--model_file",
        type=str,
        help="model path",
    )
    parser.add_argument(
        "--output_rttm_file",
        type=str,
        help="output rttm path",
    )
    parser.add_argument(
        "--wav_scp_file",
        type=str,
        default="wav.scp",
        help="input data path",
    )
    parser.add_argument(
        "--frame_shift",
        type=int,
        default=80,
        help="frame shift",
    )
    parser.add_argument(
        "--frame_size",
        type=int,
        default=200,
        help="frame size",
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=7,
        help="context size",
    )
    # The sampling rate converts sample counts into seconds when the RTTM
    # segments are written, so it must be given in Hz. The previous default
    # (10, identical to --subsampling) stretched every segment time by a
    # factor of 800 relative to 8 kHz audio.
    # NOTE(review): 8000 assumes 8 kHz telephone speech (CALLHOME) -- pass
    # --sampling_rate explicitly for other corpora.
    parser.add_argument(
        "--sampling_rate",
        type=int,
        default=8000,
        help="sampling rate",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        default=10,
        help="setting subsampling",
    )
    parser.add_argument(
        "--attractor_threshold",
        type=float,
        default=0.5,
        help="threshold for selecting attractors",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    args = parser.parse_args()

    # Merge the model config into the argument namespace. Values given on the
    # command line (including argparse defaults) take precedence: a config key
    # is copied only when the namespace does not already have that attribute.
    with open(args.config_file) as f:
        configs = yaml.safe_load(f) or {}  # an empty YAML file loads as None
    for k, v in configs.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # `seed` and `shuffle` have no command-line option, so they can only come
    # from the config file. Fail early with a readable message instead of an
    # AttributeError further down.
    missing = [key for key in ("seed", "shuffle") if not hasattr(args, key)]
    if missing:
        parser.error(
            "config file {} must define: {}".format(args.config_file, ", ".join(missing))
        )

    # Seed every random number generator used below.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    os.environ['PYTORCH_SEED'] = str(args.seed)

    model, _ = build_model_from_file(config_file=args.config_file, model_file=args.model_file, task_name="diar",
                                     device=args.device)
    model.eval()

    # wav.scp: one recording per line, "<wav_id> <path>". Only the first two
    # whitespace-separated fields are used; blank lines are skipped. If an id
    # appears more than once, the last path wins.
    wav_items = {}
    with open(args.wav_scp_file) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            wav_items[fields[0]] = fields[1]

    print("Start inference")
    fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
    # Number of audio samples covered by one model output frame.
    samples_per_frame = args.frame_shift * args.subsampling
    with open(args.output_rttm_file, "w") as wf:
        for wav_id, wav_path in wav_items.items():
            print("Process wav: {}\n".format(wav_id))
            # The sample rate reported by the file is not used; segment times
            # rely on --sampling_rate instead.
            data, _ = sf.read(wav_path)

            # Feature pipeline from eend_ola_feature: stft -> transform ->
            # splice, then keep every `subsampling`-th frame.
            speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
            speech = eend_ola_feature.transform(speech)
            speech = eend_ola_feature.splice(speech, context_size=args.context_size)
            speech = speech[::args.subsampling]  # sampling
            speech = torch.from_numpy(speech)

            with torch.no_grad():
                speech = speech.to(args.device)
                ys, _, _, _ = model.estimate_sequential(
                    [speech],
                    n_speakers=None,
                    th=args.attractor_threshold,
                    shuffle=args.shuffle
                )

            # Median filter with an 11x1 kernel: smooths along the first axis
            # while treating every column independently.
            # NOTE(review): the run detection below compares neighbouring
            # frames for inequality, so it presumably expects `ys` to be
            # binarized already -- confirm what estimate_sequential returns.
            a = ys[0].cpu().numpy()
            a = medfilt(a, (11, 1))

            # One column per speaker.
            for spkr_id, frames in enumerate(a.T):
                # Zero-pad both ends so that activity touching the first or
                # last frame still yields a change point.
                frames = np.pad(frames, (1, 1), 'constant')
                changes, = np.where(np.diff(frames, axis=0) != 0)
                # Change points alternate start / end; pair them up.
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s * samples_per_frame / args.sampling_rate
                    dur = (e - s) * samples_per_frame / args.sampling_rate
                    print(fmt.format(
                        wav_id,
                        st,
                        dur,
                        wav_id + "_" + str(spkr_id)), file=wf)
egs/callhome/eend_ola/local/random_mixture.py
@@ -42,7 +42,7 @@
import argparse
import os
-from eend import kaldi_data
+from funasr.modules.eend_ola.utils import kaldi_data
import random
import numpy as np
import json
egs/callhome/eend_ola/local/run_prepare_shared_eda.sh
@@ -9,7 +9,7 @@
#   - data/simu_${simu_outputs}
#     simulation mixtures generated with various options
-stage=1
+stage=0
# Modify corpus directories
#  - callhome_dir