From 2ff405b2f4ab899eff9bece232969fbb0c8f0555 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 20 六月 2023 00:26:37 +0800
Subject: [PATCH] Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer
---
funasr/bin/sv_inference_launch.py | 106 ++++++++++++++++++++++-------------------------------
1 files changed, 44 insertions(+), 62 deletions(-)
diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index dbddd9f..d165736 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,5 +1,5 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -7,20 +7,6 @@
import logging
import os
import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-import argparse
-import logging
-import os
-import sys
-from pathlib import Path
-from typing import Any
-from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -30,61 +16,59 @@
import torch
from kaldiio import WriteHelper
from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
-from funasr.torch_utils.device_funcs import to_device
+from funasr.bin.sv_infer import Speech2Xvector
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.utils.misc import statistic_model_parameters
-from funasr.bin.sv_infer import Speech2Xvector
+
def inference_sv(
- output_dir: Optional[str] = None,
- batch_size: int = 1,
- dtype: str = "float32",
- ngpu: int = 1,
- seed: int = 0,
- num_workers: int = 0,
- log_level: Union[int, str] = "INFO",
- key_file: Optional[str] = None,
- sv_train_config: Optional[str] = "sv.yaml",
- sv_model_file: Optional[str] = "sv.pb",
- model_tag: Optional[str] = None,
- allow_variable_data_keys: bool = True,
- streaming: bool = False,
- embedding_node: str = "resnet1_dense",
- sv_threshold: float = 0.9465,
- param_dict: Optional[dict] = None,
- **kwargs,
+ output_dir: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ ngpu: int = 1,
+ seed: int = 0,
+ num_workers: int = 0,
+ log_level: Union[int, str] = "INFO",
+ key_file: Optional[str] = None,
+ sv_train_config: Optional[str] = "sv.yaml",
+ sv_model_file: Optional[str] = "sv.pb",
+ model_tag: Optional[str] = None,
+ allow_variable_data_keys: bool = True,
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ sv_threshold: float = 0.9465,
+ param_dict: Optional[dict] = None,
+ **kwargs,
):
assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
-
+
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
-
+
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.info("param_dict: {}".format(param_dict))
-
+
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
-
+
# 1. Set random-seed
set_all_random_seed(seed)
-
+
# 2. Build speech2xvector
speech2xvector_kwargs = dict(
sv_train_config=sv_train_config,
@@ -100,32 +84,31 @@
**speech2xvector_kwargs,
)
speech2xvector.sv_model.eval()
-
+
def _forward(
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- param_dict: Optional[dict] = None,
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ param_dict: Optional[dict] = None,
):
logging.info("param_dict: {}".format(param_dict))
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-
+
# 3. Build data-iterator
- loader = SVTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="sv",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=None,
- collate_fn=None,
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
+ use_collate_fn=False,
)
-
+
# 7 .Start for-loop
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
embd_writer, ref_embd_writer, score_writer = None, None, None
@@ -139,7 +122,7 @@
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
-
+
embedding, ref_embedding, score = speech2xvector(**batch)
# Only supporting batch_size==1
key = keys[0]
@@ -161,18 +144,16 @@
score_writer = open(os.path.join(output_path, "score.txt"), "w")
ref_embd_writer(key, ref_embedding[0].cpu().numpy())
score_writer.write("{} {:.6f}\n".format(key, normalized_score))
-
+
if output_path is not None:
embd_writer.close()
if ref_embd_writer is not None:
ref_embd_writer.close()
score_writer.close()
-
+
return sv_result_list
-
+
return _forward
-
-
def inference_launch(mode, **kwargs):
@@ -182,6 +163,7 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
--
Gitblit v1.9.1