From 729ac44000a5c4aa23dcb1a3b80adc119a350b23 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 15 三月 2023 19:10:16 +0800
Subject: [PATCH] Merge pull request #235 from alibaba-damo-academy/dev_wjm
---
funasr/bin/eend_ola_inference.py | 24 +++++++++++++++++++-----
1 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index d65895f..0483278 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,6 +16,7 @@
import numpy as np
import torch
+from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
@@ -146,7 +147,7 @@
output_dir: Optional[str] = None,
batch_size: int = 1,
dtype: str = "float32",
- ngpu: int = 0,
+ ngpu: int = 1,
num_workers: int = 0,
log_level: Union[int, str] = "INFO",
key_file: Optional[str] = None,
@@ -179,7 +180,6 @@
diar_model_file=diar_model_file,
device=device,
dtype=dtype,
- streaming=streaming,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
speech2diar = Speech2Diarization.from_pretrained(
@@ -209,7 +209,7 @@
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
loader = EENDOLADiarTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
@@ -236,9 +236,23 @@
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
results = speech2diar(**batch)
+
+ # post process
+ a = results[0][0].cpu().numpy()
+ a = medfilt(a, (11, 1))
+ rst = []
+ for spkid, frames in enumerate(a.T):
+ frames = np.pad(frames, (1, 1), 'constant')
+ changes, = np.where(np.diff(frames, axis=0) != 0)
+ fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+ for s, e in zip(changes[::2], changes[1::2]):
+ st = s / 10.
+ dur = (e - s) / 10.
+ rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
+
# Only supporting batch_size==1
- key, value = keys[0], output_results_str(results, keys[0])
- item = {"key": key, "value": value}
+ value = "\n".join(rst)
+ item = {"key": keys[0], "value": value}
result_list.append(item)
if output_path is not None:
output_writer.write(value)
--
Gitblit v1.9.1