From 2ff405b2f4ab899eff9bece232969fbb0c8f0555 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 20 六月 2023 00:26:37 +0800
Subject: [PATCH] Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer
---
funasr/bin/sv_infer.py | 28 ++++++++--------------------
1 files changed, 8 insertions(+), 20 deletions(-)
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index 1517bfa..6e861da 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -1,35 +1,24 @@
-# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-import argparse
import logging
-import os
-import sys
from pathlib import Path
from typing import Any
-from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
import numpy as np
import torch
-from kaldiio import WriteHelper
from typeguard import check_argument_types
from typeguard import check_return_type
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.tasks.sv import SVTask
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
from funasr.utils.misc import statistic_model_parameters
+
class Speech2Xvector:
"""Speech2Xvector class
@@ -56,10 +45,13 @@
assert check_argument_types()
# TODO: 1. Build SV model
- sv_model, sv_train_args = SVTask.build_model_from_file(
+ sv_model, sv_train_args = build_model_from_file(
config_file=sv_train_config,
model_file=sv_model_file,
- device=device
+ cmvn_file=None,
+ device=device,
+ task_name="sv",
+ mode="sv",
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
@@ -157,7 +149,3 @@
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Xvector(**kwargs)
-
-
-
-
--
Gitblit v1.9.1