From f6a1cdaf3488c9ec572e1f753b50cb58a0f8fd79 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期五, 10 二月 2023 18:56:14 +0800
Subject: [PATCH] add sond model

---
 funasr/bin/sv_inference.py |   29 ++++++++++++++++-------------
 1 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index 57ce91d..a78bccd 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import argparse
 import logging
 import os
@@ -26,7 +29,7 @@
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-
+from funasr.utils.misc import statistic_model_parameters
 
 class Speech2Xvector:
     """Speech2Xvector class
@@ -59,6 +62,7 @@
             device=device
         )
         logging.info("sv_model: {}".format(sv_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
         logging.info("sv_train_args: {}".format(sv_train_args))
         sv_model.to(dtype=getattr(torch, dtype)).eval()
 
@@ -156,17 +160,17 @@
 
 
 def inference_modelscope(
-        output_dir: Optional[str],
-        batch_size: int,
-        dtype: str,
-        ngpu: int,
-        seed: int,
-        num_workers: int,
-        log_level: Union[int, str],
-        key_file: Optional[str],
-        sv_train_config: Optional[str],
-        sv_model_file: Optional[str],
-        model_tag: Optional[str],
+        output_dir: Optional[str] = None,
+        batch_size: int = 1,
+        dtype: str = "float32",
+        ngpu: int = 1,
+        seed: int = 0,
+        num_workers: int = 0,
+        log_level: Union[int, str] = "INFO",
+        key_file: Optional[str] = None,
+        sv_train_config: Optional[str] = "sv.yaml",
+        sv_model_file: Optional[str] =  "sv.pth",
+        model_tag: Optional[str] = None,
         allow_variable_data_keys: bool = True,
         streaming: bool = False,
         embedding_node: str = "resnet1_dense",
@@ -214,7 +218,6 @@
             data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
             raw_inputs: Union[np.ndarray, torch.Tensor] = None,
             output_dir_v2: Optional[str] = None,
-            fs: dict = None,
             param_dict: Optional[dict] = None,
     ):
         logging.info("param_dict: {}".format(param_dict))

--
Gitblit v1.9.1