From c3bce4c288f73a3bbf5559b019d4480f95acffaa Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 16 三月 2023 10:44:15 +0800
Subject: [PATCH] update

---
 funasr/bin/eend_ola_inference.py |   11 ++++++-----
 1 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index b35824a..bc29fa2 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,8 +16,8 @@
 
 import numpy as np
 import torch
-from typeguard import check_argument_types
 from scipy.signal import medfilt
+from typeguard import check_argument_types
 
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.tasks.diar import EENDOLADiarTask
@@ -27,6 +27,7 @@
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
+
 
 class Speech2Diarization:
     """Speech2Diarlization class
@@ -208,7 +209,7 @@
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
+            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
         loader = EENDOLADiarTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
@@ -237,7 +238,7 @@
             results = speech2diar(**batch)
 
             # post process
-            a = results[0].cpu().numpy()
+            a = results[0][0].cpu().numpy()
             a = medfilt(a, (11, 1))
             rst = []
             for spkid, frames in enumerate(a.T):
@@ -246,8 +247,8 @@
                 fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                 for s, e in zip(changes[::2], changes[1::2]):
                     st = s / 10.
-                    ed = e / 10.
-                    rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
+                    dur = (e - s) / 10.
+                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
 
             # Only supporting batch_size==1
             value = "\n".join(rst)

--
Gitblit v1.9.1