From d2dc3af1a69ee4075bcfc0c83dc0fb8e3fc1db4e Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期四, 11 五月 2023 16:31:40 +0800
Subject: [PATCH] Merge pull request #492 from alibaba-damo-academy/dev_smohan
---
funasr/bin/sv_inference.py | 34 ++++++++++++++++++++--------------
1 files changed, 20 insertions(+), 14 deletions(-)
diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py
index 57ce91d..76b1dfb 100755
--- a/funasr/bin/sv_inference.py
+++ b/funasr/bin/sv_inference.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import argparse
import logging
import os
@@ -26,14 +29,14 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-
+from funasr.utils.misc import statistic_model_parameters
class Speech2Xvector:
"""Speech2Xvector class
Examples:
>>> import soundfile
- >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pth")
+ >>> speech2xvector = Speech2Xvector("sv_config.yml", "sv.pb")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2xvector(audio)
[(text, token, token_int, hypothesis object), ...]
@@ -59,6 +62,7 @@
device=device
)
logging.info("sv_model: {}".format(sv_model))
+ logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
logging.info("sv_train_args: {}".format(sv_train_args))
sv_model.to(dtype=getattr(torch, dtype)).eval()
@@ -156,17 +160,17 @@
def inference_modelscope(
- output_dir: Optional[str],
- batch_size: int,
- dtype: str,
- ngpu: int,
- seed: int,
- num_workers: int,
- log_level: Union[int, str],
- key_file: Optional[str],
- sv_train_config: Optional[str],
- sv_model_file: Optional[str],
- model_tag: Optional[str],
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pb",
+ model_tag: Optional[str] = None,
allow_variable_data_keys: bool = True,
streaming: bool = False,
embedding_node: str = "resnet1_dense",
@@ -175,6 +179,9 @@
**kwargs,
):
assert check_argument_types()
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
@@ -214,7 +221,6 @@
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
- fs: dict = None,
param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
--
Gitblit v1.9.1