From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/bin/eend_ola_inference.py |   30 ++++++++++++++++++++----------
 1 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index 96e7516..01d3f29 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import torch
+from scipy.signal import medfilt
 from typeguard import check_argument_types
 
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
@@ -27,8 +28,6 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 
-from modelscope.utils.logger import get_logger
-logger = get_logger()
 
 class Speech2Diarization:
     """Speech2Diarlization class
@@ -36,7 +35,7 @@
     Examples:
         >>> import soundfile
         >>> import numpy as np
-        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
+        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
         >>> profile = np.load("profiles.npy")
         >>> audio, rate = soundfile.read("speech.wav")
         >>> speech2diar(audio, profile)
@@ -148,7 +147,7 @@
         output_dir: Optional[str] = None,
         batch_size: int = 1,
         dtype: str = "float32",
-        ngpu: int = 0,
+        ngpu: int = 1,
         num_workers: int = 0,
         log_level: Union[int, str] = "INFO",
         key_file: Optional[str] = None,
@@ -210,8 +209,7 @@
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        logger.info(data_path_and_name_and_type)
+            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
         loader = EENDOLADiarTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
@@ -231,8 +229,6 @@
             output_writer = open("{}/result.txt".format(output_path), "w")
         result_list = []
         for keys, batch in loader:
-            logger.info("keys: {}".format(keys))
-            logger.info("batch: {}".format(batch))
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
@@ -240,9 +236,23 @@
             # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
 
             results = speech2diar(**batch)
+
+            # post process
+            a = results[0][0].cpu().numpy()
+            a = medfilt(a, (11, 1))
+            rst = []
+            for spkid, frames in enumerate(a.T):
+                frames = np.pad(frames, (1, 1), 'constant')
+                changes, = np.where(np.diff(frames, axis=0) != 0)
+                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+                for s, e in zip(changes[::2], changes[1::2]):
+                    st = s / 10.
+                    dur = (e - s) / 10.
+                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
+
             # Only supporting batch_size==1
-            key, value = keys[0], output_results_str(results, keys[0])
-            item = {"key": key, "value": value}
+            value = "\n".join(rst)
+            item = {"key": keys[0], "value": value}
             result_list.append(item)
             if output_path is not None:
                 output_writer.write(value)

--
Gitblit v1.9.1