Chong Zhang
2023-05-23 5fec3c9e58fceda85fa2daf7deec2492372dac8a
funasr/bin/diar_inference_launch.py
@@ -1,3 +1,4 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -37,7 +38,6 @@
from scipy.signal import medfilt
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.diar import DiarTask
from funasr.tasks.asr import ASRTask
from funasr.tasks.diar import EENDOLADiarTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -186,7 +186,7 @@
                raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ")
        else:
            # 3. Build data-iterator
            loader = ASRTask.build_streaming_iterator(
            loader = DiarTask.build_streaming_iterator(
                data_path_and_name_and_type,
                dtype=dtype,
                batch_size=batch_size,
@@ -362,6 +362,30 @@
    return _forward
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request to the backend for ``mode``.

    Args:
        mode: Decoding mode. ``"sond"`` and ``"sond_demo"`` are routed to
            ``inference_sond``; ``"eend-ola"`` is routed to ``inference_eend``.
        **kwargs: Forwarded to the selected backend together with ``mode``.
            For ``"sond_demo"``, ``kwargs["param_dict"]`` is replaced by a new
            dict holding the demo defaults (``extract_profile``,
            ``sv_train_config``, ``sv_model_file``) overlaid with any
            caller-supplied entries; the caller's own dict is left untouched.

    Returns:
        Whatever the selected backend returns, or ``None`` if ``mode`` is not
        recognised.
    """
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    elif mode == "sond_demo":
        # Demo defaults: turn on profile extraction and name the SV config/model files.
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        # Build a fresh dict instead of writing into the caller's param_dict;
        # caller-supplied values take precedence over the defaults.
        merged = dict(demo_defaults)
        if user_params is not None:
            merged.update(user_params)
        kwargs["param_dict"] = merged
        return inference_sond(mode=mode, **kwargs)
    elif mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    else:
        # An unrecognised mode is an error; log it at ERROR so it is not hidden.
        logging.error("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Speaker Verification",
@@ -469,29 +493,6 @@
    )
    return parser
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request to the backend for ``mode``.

    Args:
        mode: Decoding mode. ``"sond"`` and ``"sond_demo"`` are routed to
            ``inference_sond``; ``"eend-ola"`` is routed to ``inference_eend``.
        **kwargs: Forwarded to the selected backend together with ``mode``.
            For ``"sond_demo"``, ``kwargs["param_dict"]`` is replaced by a new
            dict holding the demo defaults (``extract_profile``,
            ``sv_train_config``, ``sv_model_file``) overlaid with any
            caller-supplied entries; the caller's own dict is left untouched.

    Returns:
        Whatever the selected backend returns, or ``None`` if ``mode`` is not
        recognised.
    """
    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    elif mode == "sond_demo":
        # Demo defaults: turn on profile extraction and name the SV config/model files.
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        # Build a fresh dict instead of writing into the caller's param_dict;
        # caller-supplied values take precedence over the defaults.
        merged = dict(demo_defaults)
        if user_params is not None:
            merged.update(user_params)
        kwargs["param_dict"] = merged
        return inference_sond(mode=mode, **kwargs)
    elif mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    else:
        # An unrecognised mode is an error; log it at ERROR so it is not hidden.
        logging.error("Unknown decoding mode: {}".format(mode))
        return None
def main(cmd=None):