| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | |
| | | from scipy.signal import medfilt |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.tasks.diar import DiarTask |
| | | from funasr.tasks.asr import ASRTask |
| | | from funasr.tasks.diar import EENDOLADiarTask |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | |
| | | raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ") |
| | | else: |
| | | # 3. Build data-iterator |
| | | loader = ASRTask.build_streaming_iterator( |
| | | loader = DiarTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request to the backend selected by ``mode``.

    NOTE(review): ``inference_launch`` is defined more than once in this module;
    the later definition shadows this one. Consider removing the duplicate.

    Args:
        mode: Backend selector. ``"sond"`` and ``"sond_demo"`` route to
            ``inference_sond``; ``"eend-ola"`` routes to ``inference_eend``.
        **kwargs: Forwarded to the selected backend. For ``"sond_demo"`` the
            ``param_dict`` entry is completed with demo defaults (see below).

    Returns:
        Whatever the selected backend returns, or ``None`` when ``mode`` is
        not recognised.

    For ``"sond_demo"``, the keys ``extract_profile`` (True),
    ``sv_train_config`` ("sv.yaml") and ``sv_model_file`` ("sv.pb") are
    filled in when the caller's ``param_dict`` does not already set them.
    The caller's dict is left untouched; a merged copy is forwarded instead.
    """
    # Imported locally so this function does not depend on a module-level
    # ``logging`` import being present.
    import logging

    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    if mode == "sond_demo":
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        # Caller-supplied values win; defaults only fill missing keys. Building
        # a new dict avoids mutating the caller's param_dict as a side effect.
        if user_params is None:
            kwargs["param_dict"] = demo_defaults
        else:
            kwargs["param_dict"] = {**demo_defaults, **user_params}
        return inference_sond(mode=mode, **kwargs)
    if mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    # An unrecognised mode is a caller error: log it at ERROR so it is not
    # hidden by the default WARNING threshold, then return None as before.
    logging.error("Unknown decoding mode: %s", mode)
    return None
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Speaker Verification", |
| | |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch a diarization inference request to the backend selected by ``mode``.

    NOTE(review): ``inference_launch`` is also defined earlier in this module;
    this later definition is the one that takes effect. Consider removing the
    duplicate.

    Args:
        mode: Backend selector. ``"sond"`` and ``"sond_demo"`` route to
            ``inference_sond``; ``"eend-ola"`` routes to ``inference_eend``.
        **kwargs: Forwarded to the selected backend. For ``"sond_demo"`` the
            ``param_dict`` entry is completed with demo defaults (see below).

    Returns:
        Whatever the selected backend returns, or ``None`` when ``mode`` is
        not recognised.

    For ``"sond_demo"``, the keys ``extract_profile`` (True),
    ``sv_train_config`` ("sv.yaml") and ``sv_model_file`` ("sv.pb") are
    filled in when the caller's ``param_dict`` does not already set them.
    The caller's dict is left untouched; a merged copy is forwarded instead.
    """
    # Imported locally so this function does not depend on a module-level
    # ``logging`` import being present.
    import logging

    if mode == "sond":
        return inference_sond(mode=mode, **kwargs)
    if mode == "sond_demo":
        demo_defaults = {
            "extract_profile": True,
            "sv_train_config": "sv.yaml",
            "sv_model_file": "sv.pb",
        }
        user_params = kwargs.get("param_dict")
        # Caller-supplied values win; defaults only fill missing keys. Building
        # a new dict avoids mutating the caller's param_dict as a side effect.
        if user_params is None:
            kwargs["param_dict"] = demo_defaults
        else:
            kwargs["param_dict"] = {**demo_defaults, **user_params}
        return inference_sond(mode=mode, **kwargs)
    if mode == "eend-ola":
        return inference_eend(mode=mode, **kwargs)
    # An unrecognised mode is a caller error: log it at ERROR so it is not
    # hidden by the default WARNING threshold, then return None as before.
    logging.error("Unknown decoding mode: %s", mode)
    return None
| | | |
| | | |
| | | def main(cmd=None): |