From 92e8d4358a0c0ea323f00fa578382252c5b18732 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 19 七月 2023 10:35:12 +0800
Subject: [PATCH] update
---
egs/callhome/eend_ola/local/infer.py | 132 ++++++++++++++++++++++++++++++++++++++++++++
egs/callhome/eend_ola/local/random_mixture.py | 2
egs/callhome/eend_ola/local/run_prepare_shared_eda.sh | 2
3 files changed, 134 insertions(+), 2 deletions(-)
diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py
new file mode 100644
index 0000000..78d160d
--- /dev/null
+++ b/egs/callhome/eend_ola/local/infer.py
@@ -0,0 +1,132 @@
+import argparse
+import os
+
+import numpy as np
+import soundfile as sf
+import torch
+import yaml
+from scipy.signal import medfilt
+
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.build_utils.build_model_from_file import build_model_from_file
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config_file",
+ type=str,
+ help="model config file",
+ )
+ parser.add_argument(
+ "--model_file",
+ type=str,
+ help="model path",
+ )
+ parser.add_argument(
+ "--output_rttm_file",
+ type=str,
+ help="output rttm path",
+ )
+ parser.add_argument(
+ "--wav_scp_file",
+ type=str,
+ default="wav.scp",
+ help="input data path",
+ )
+ parser.add_argument(
+ "--frame_shift",
+ type=int,
+ default=80,
+ help="frame shift",
+ )
+ parser.add_argument(
+ "--frame_size",
+ type=int,
+ default=200,
+ help="frame size",
+ )
+ parser.add_argument(
+ "--context_size",
+ type=int,
+ default=7,
+ help="context size",
+ )
+ parser.add_argument(
+ "--sampling_rate",
+ type=int,
+ default=10,
+ help="sampling rate",
+ )
+ parser.add_argument(
+ "--subsampling",
+ type=int,
+ default=10,
+ help="setting subsampling",
+ )
+ parser.add_argument(
+ "--attractor_threshold",
+ type=float,
+ default=0.5,
+ help="threshold for selecting attractors",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ )
+ args = parser.parse_args()
+
+ with open(args.config_file) as f:
+ configs = yaml.safe_load(f)
+ for k, v in configs.items():
+ if not hasattr(args, k):
+ setattr(args, k, v)
+
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ os.environ['PYTORCH_SEED'] = str(args.seed)
+
+ model, _ = build_model_from_file(config_file=args.config_file, model_file=args.model_file, task_name="diar",
+ device=args.device)
+ model.eval()
+
+ with open(args.wav_scp_file) as f:
+ wav_lines = [line.strip().split() for line in f.readlines()]
+ wav_items = {x[0]: x[1] for x in wav_lines}
+
+ print("Start inference")
+ with open(args.output_rttm_file, "w") as wf:
+ for wav_id in wav_items.keys():
+ print("Process wav: {}\n".format(wav_id))
+ data, rate = sf.read(wav_items[wav_id])
+ speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
+ speech = eend_ola_feature.transform(speech)
+ speech = eend_ola_feature.splice(speech, context_size=args.context_size)
+ speech = speech[::args.subsampling] # sampling
+ speech = torch.from_numpy(speech)
+
+ with torch.no_grad():
+ speech = speech.to(args.device)
+ ys, _, _, _ = model.estimate_sequential(
+ [speech],
+ n_speakers=None,
+ th=args.attractor_threshold,
+ shuffle=args.shuffle
+ )
+
+ a = ys[0].cpu().numpy()
+ a = medfilt(a, (11, 1))
+ rst = []
+ for spkr_id, frames in enumerate(a.T):
+ frames = np.pad(frames, (1, 1), 'constant')
+ changes, = np.where(np.diff(frames, axis=0) != 0)
+ fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+ for s, e in zip(changes[::2], changes[1::2]):
+ st = s * args.frame_shift * args.subsampling / args.sampling_rate
+ dur = (e - s) * args.frame_shift * args.subsampling / args.sampling_rate
+ print(fmt.format(
+ wav_id,
+ st,
+ dur,
+ wav_id + "_" + str(spkr_id)), file=wf)
\ No newline at end of file
diff --git a/egs/callhome/eend_ola/local/random_mixture.py b/egs/callhome/eend_ola/local/random_mixture.py
index 0032ef9..05d7828 100755
--- a/egs/callhome/eend_ola/local/random_mixture.py
+++ b/egs/callhome/eend_ola/local/random_mixture.py
@@ -42,7 +42,7 @@
import argparse
import os
-from eend import kaldi_data
+from funasr.modules.eend_ola.utils import kaldi_data
import random
import numpy as np
import json
diff --git a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh
index 5431ba1..aec1ff2 100755
--- a/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh
+++ b/egs/callhome/eend_ola/local/run_prepare_shared_eda.sh
@@ -9,7 +9,7 @@
# - data/simu_${simu_outputs}
# simulation mixtures generated with various options
-stage=1
+stage=0
# Modify corpus directories
# - callhome_dir
--
Gitblit v1.9.1