From b5d3df75cf6462aa3bf42fd3c86fa2aa7f1c8a15 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 十一月 2023 00:54:44 +0800
Subject: [PATCH] setup jamo

---
 funasr/bin/diar_inference_launch.py |   33 +++++++++++++++------------------
 1 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 820217b..f5a11b1 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -15,10 +15,10 @@
 from typing import Union
 
 import numpy as np
-import soundfile
+# import librosa
+import librosa
 import torch
 from scipy.signal import medfilt
-from typeguard import check_argument_types
 
 from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
 from funasr.datasets.iterable_dataset import load_bytes
@@ -52,7 +52,6 @@
         mode: str = "sond",
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -94,10 +93,7 @@
             embedding_node="resnet1_dense"
         )
         logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
-        speech2xvector = Speech2Xvector.from_pretrained(
-            model_tag=model_tag,
-            **speech2xvector_kwargs,
-        )
+        speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
         speech2xvector.sv_model.eval()
 
     # 2b. Build speech2diar
@@ -111,10 +107,7 @@
         dur_threshold=dur_threshold,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationSOND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):
@@ -152,7 +145,9 @@
                         # read waveform file
                         example = [load_bytes(x) if isinstance(x, bytes) else x
                                    for x in example]
-                        example = [soundfile.read(x)[0] if isinstance(x, str) else x
+                        # example = [librosa.load(x)[0] if isinstance(x, str) else x
+                        #            for x in example]
+                        example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x
                                    for x in example]
                         # convert torch tensor to numpy array
                         example = [x.numpy() if isinstance(example[0], torch.Tensor) else x
@@ -233,7 +228,6 @@
         param_dict: Optional[dict] = None,
         **kwargs,
 ):
-    assert check_argument_types()
     ncpu = kwargs.get("ncpu", 1)
     torch.set_num_threads(ncpu)
     if batch_size > 1:
@@ -260,10 +254,7 @@
         dtype=dtype,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
-    speech2diar = Speech2DiarizationEEND.from_pretrained(
-        model_tag=model_tag,
-        **speech2diar_kwargs,
-    )
+    speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
     speech2diar.diar_model.eval()
 
     def output_results_str(results: dict, uttid: str):
@@ -465,11 +456,17 @@
         help="The batch size for inference",
     )
     group.add_argument(
-        "--diar_smooth_size",
+        "--smooth_size",
         type=int,
         default=121,
         help="The smoothing size for post-processing"
     )
+    group.add_argument(
+        "--dur_threshold",
+        type=int,
+        default=10,
+        help="The threshold of minimum duration"
+    )
 
     return parser
 

--
Gitblit v1.9.1