From e1ba6bc138b4e73875c64f35f98f3b15a0560e92 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 17 五月 2023 15:16:06 +0800
Subject: [PATCH] Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer
---
funasr/bin/diar_inference_launch.py | 51 ++++++++++++++++++++++++++-------------------------
1 files changed, 26 insertions(+), 25 deletions(-)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 08004e8..e0d900e 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
@@ -37,7 +38,6 @@
from scipy.signal import medfilt
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
-from funasr.tasks.asr import ASRTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -186,7 +186,7 @@
raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
else:
# 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
+ loader = DiarTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
@@ -362,6 +362,30 @@
return _forward
+
+
+def inference_launch(mode, **kwargs):
+ if mode == "sond":
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "sond_demo":
+ param_dict = {
+ "extract_profile": True,
+ "sv_train_config": "sv.yaml",
+ "sv_model_file": "sv.pb",
+ }
+ if "param_dict" in kwargs and kwargs["param_dict"] is not None:
+ for key in param_dict:
+ if key not in kwargs["param_dict"]:
+ kwargs["param_dict"][key] = param_dict[key]
+ else:
+ kwargs["param_dict"] = param_dict
+ return inference_sond(mode=mode, **kwargs)
+ elif mode == "eend-ola":
+ return inference_eend(mode=mode, **kwargs)
+ else:
+ logging.info("Unknown decoding mode: {}".format(mode))
+ return None
+
def get_parser():
parser = config_argparse.ArgumentParser(
description="Speaker Verification",
@@ -469,29 +493,6 @@
)
return parser
-
-
-def inference_launch(mode, **kwargs):
- if mode == "sond":
- return inference_sond(mode=mode, **kwargs)
- elif mode == "sond_demo":
- param_dict = {
- "extract_profile": True,
- "sv_train_config": "sv.yaml",
- "sv_model_file": "sv.pb",
- }
- if "param_dict" in kwargs and kwargs["param_dict"] is not None:
- for key in param_dict:
- if key not in kwargs["param_dict"]:
- kwargs["param_dict"][key] = param_dict[key]
- else:
- kwargs["param_dict"] = param_dict
- return inference_sond(mode=mode, **kwargs)
- elif mode == "eend-ola":
- return inference_eend(mode=mode, **kwargs)
- else:
- logging.info("Unknown decoding mode: {}".format(mode))
- return None
def main(cmd=None):
--
Gitblit v1.9.1