From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/bin/eend_ola_inference.py | 13 +++++++------
1 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index b35824a..01d3f29 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,8 +16,8 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from scipy.signal import medfilt
+from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask
@@ -28,13 +28,14 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+
class Speech2Diarization:
"""Speech2Diarlization class
Examples:
>>> import soundfile
>>> import numpy as np
- >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
+ >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
>>> profile = np.load("profiles.npy")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2diar(audio, profile)
@@ -208,7 +209,7 @@
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
+ data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
loader = EENDOLADiarTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
@@ -237,7 +238,7 @@
results = speech2diar(**batch)
# post process
- a = results[0].cpu().numpy()
+ a = results[0][0].cpu().numpy()
a = medfilt(a, (11, 1))
rst = []
for spkid, frames in enumerate(a.T):
@@ -246,8 +247,8 @@
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
for s, e in zip(changes[::2], changes[1::2]):
st = s / 10.
- ed = e / 10.
- rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
+ dur = (e - s) / 10.
+ rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
# Only supporting batch_size==1
value = "\n".join(rst)
--
Gitblit v1.9.1