Chong Zhang
2023-05-23 5fec3c9e58fceda85fa2daf7deec2492372dac8a
funasr/bin/sv_inference_launch.py
@@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -34,7 +34,6 @@
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.sv import SVTask
from funasr.tasks.asr import ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
@@ -115,7 +114,7 @@
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        
        # 3. Build data-iterator
        loader = ASRTask.build_streaming_iterator(
        loader = SVTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
@@ -173,6 +172,15 @@
    
    return _forward
def inference_launch(mode, **kwargs):
    """Dispatch to the inference entry point selected by ``mode``.

    Args:
        mode: Inference mode name. Only ``"sv"`` is recognised here.
        **kwargs: Forwarded unchanged to ``inference_sv``.

    Returns:
        The result of ``inference_sv(**kwargs)`` when ``mode == "sv"``
        (presumably the ``_forward`` callable it builds — confirm against
        ``inference_sv``); ``None`` for any other mode.
    """
    if mode == "sv":
        return inference_sv(**kwargs)
    # An unrecognised mode yields None, which callers are unlikely to expect.
    # Report it at ERROR level: an INFO record is suppressed under the root
    # logger's default WARNING threshold, hiding the reason for the None.
    logging.error("Unknown decoding mode: %s", mode)
    return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -287,14 +295,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Dispatch to the inference entry point selected by ``mode``.

    Args:
        mode: Inference mode name. Only ``"sv"`` is recognised here.
        **kwargs: Forwarded unchanged to ``inference_sv``.

    Returns:
        The result of ``inference_sv(**kwargs)`` when ``mode == "sv"``
        (presumably the ``_forward`` callable it builds — confirm against
        ``inference_sv``); ``None`` for any other mode.
    """
    if mode == "sv":
        return inference_sv(**kwargs)
    # An unrecognised mode yields None, which callers are unlikely to expect.
    # Report it at ERROR level: an INFO record is suppressed under the root
    # logger's default WARNING threshold, hiding the reason for the None.
    logging.error("Unknown decoding mode: %s", mode)
    return None
def main(cmd=None):