From e1ba6bc138b4e73875c64f35f98f3b15a0560e92 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 17 五月 2023 15:16:06 +0800
Subject: [PATCH] Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer

---
 funasr/bin/sv_inference_launch.py |   22 +++++++++++-----------
 1 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py
index 24b8638..dbddd9f 100755
--- a/funasr/bin/sv_inference_launch.py
+++ b/funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
+# -*- encoding: utf-8 -*-
 #!/usr/bin/env python3
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
-
 
 import argparse
 import logging
@@ -34,7 +34,6 @@
 
 from funasr.utils.cli_utils import get_commandline_args
 from funasr.tasks.sv import SVTask
-from funasr.tasks.asr import ASRTask
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
@@ -115,7 +114,7 @@
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
         
         # 3. Build data-iterator
-        loader = ASRTask.build_streaming_iterator(
+        loader = SVTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
@@ -173,6 +172,15 @@
     
     return _forward
 
+
+
+
+def inference_launch(mode, **kwargs):
+    if mode == "sv":
+        return inference_sv(**kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
 
 def get_parser():
     parser = config_argparse.ArgumentParser(
@@ -287,14 +295,6 @@
     )
 
     return parser
-
-
-def inference_launch(mode, **kwargs):
-    if mode == "sv":
-        return inference_sv(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
 
 
 def main(cmd=None):

--
Gitblit v1.9.1