From 2ff405b2f4ab899eff9bece232969fbb0c8f0555 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 20 六月 2023 00:26:37 +0800
Subject: [PATCH] Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer

---
 funasr/bin/sv_infer.py |   28 ++++++++--------------------
 1 files changed, 8 insertions(+), 20 deletions(-)

diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 1517bfa..6e861da 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,35 +1,24 @@
-# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
-import argparse
 import logging
-import os
-import sys
 from pathlib import Path
 from typing import Any
-from typing import List
 from typing import Optional
-from typing import Sequence
 from typing import Tuple
 from typing import Union
 
 import numpy as np
 import torch
-from kaldiio import WriteHelper
 from typeguard import check_argument_types
 from typeguard import check_return_type
 
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
 from funasr.utils.misc import statistic_model_parameters
+
 
 class Speech2Xvector:
     """Speech2Xvector class
@@ -56,10 +45,13 @@
         assert check_argument_types()
 
         # TODO: 1. Build SV model
-        sv_model, sv_train_args = SVTask.build_model_from_file(
+        sv_model, sv_train_args = build_model_from_file(
             config_file=sv_train_config,
             model_file=sv_model_file,
-            device=device
+            cmvn_file=None,
+            device=device,
+            task_name="sv",
+            mode="sv",
         )
         logging.info("sv_model: {}".format(sv_model))
         logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -157,7 +149,3 @@
             kwargs.update(**d.download_and_unpack(model_tag))
 
         return Speech2Xvector(**kwargs)
-
-
-
-

--
Gitblit v1.9.1