From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/diar_inference_launch.py | 26 ++++++++++----------------
1 files changed, 10 insertions(+), 16 deletions(-)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 820217b..b655df5 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -18,7 +18,6 @@
import soundfile
import torch
from scipy.signal import medfilt
-from typeguard import check_argument_types
from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND
from funasr.datasets.iterable_dataset import load_bytes
@@ -52,7 +51,6 @@
mode: str = "sond",
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -94,10 +92,7 @@
embedding_node="resnet1_dense"
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector.from_pretrained(
- model_tag=model_tag,
- **speech2xvector_kwargs,
- )
+ speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
speech2xvector.sv_model.eval()
# 2b. Build speech2diar
@@ -111,10 +106,7 @@
dur_threshold=dur_threshold,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationSOND.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
+ speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
@@ -233,7 +225,6 @@
param_dict: Optional[dict] = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -260,10 +251,7 @@
dtype=dtype,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationEEND.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
+ speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
@@ -465,11 +453,17 @@
help="The batch size for inference",
)
group.add_argument(
- "--diar_smooth_size",
+ "--smooth_size",
type=int,
default=121,
help="The smoothing size for post-processing"
)
+ group.add_argument(
+ "--dur_threshold",
+ type=int,
+ default=10,
+ help="The threshold of minimum duration"
+ )
return parser
--
Gitblit v1.9.1