| | |
| | | # !/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | |
| | | import logging |
| | | import os |
| | | import sys |
| | | from typing import Union, Dict, Any |
| | | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | from collections import OrderedDict |
| | | import numpy as np |
| | | import soundfile |
| | | import torch |
| | | from torch.nn import functional as F |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | from scipy.signal import medfilt |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.tasks.diar import DiarTask |
| | | from funasr.tasks.diar import EENDOLADiarTask |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND |
| | | from funasr.datasets.iterable_dataset import load_bytes |
| | | from funasr.build_utils.build_streaming_iterator import build_streaming_iterator |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from scipy.ndimage import median_filter |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | from funasr.datasets.iterable_dataset import load_bytes |
| | | from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND |
| | | |
| | | |
| | | def inference_sond( |
| | | diar_train_config: str, |
| | |
| | | set_all_random_seed(seed) |
| | | |
| | | # 2a. Build speech2xvec [Optional] |
| | | if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict["extract_profile"]: |
| | | if mode == "sond_demo" and param_dict is not None and "extract_profile" in param_dict and param_dict[ |
| | | "extract_profile"]: |
| | | assert "sv_train_config" in param_dict, "sv_train_config must be provided param_dict." |
| | | assert "sv_model_file" in param_dict, "sv_model_file must be provided in param_dict." |
| | | sv_train_config = param_dict["sv_train_config"] |
| | |
| | | rst = [] |
| | | mid = uttid.rsplit("-", 1)[0] |
| | | for key in results: |
| | | results[key] = [(x[0]/100, x[1]/100) for x in results[key]] |
| | | results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]] |
| | | if out_format == "vad": |
| | | for spk, segs in results.items(): |
| | | rst.append("{} {}".format(spk, segs)) |
| | |
| | | example = [x.numpy() if isinstance(example[0], torch.Tensor) else x |
| | | for x in example] |
| | | speech = example[0] |
| | | logging.info("Extracting profiles for {} waveforms".format(len(example)-1)) |
| | | logging.info("Extracting profiles for {} waveforms".format(len(example) - 1)) |
| | | profile = [speech2xvector.calculate_embedding(x) for x in example[1:]] |
| | | profile = torch.cat(profile, dim=0) |
| | | yield ["test{}".format(idx)], {"speech": [speech], "profile": [profile]} |
| | |
| | | raise TypeError("raw_inputs must be a list or tuple in [speech, profile1, profile2, ...] ") |
| | | else: |
| | | # 3. Build data-iterator |
| | | loader = DiarTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | loader = build_streaming_iterator( |
| | | task_name="diar", |
| | | preprocess_args=None, |
| | | data_path_and_name_and_type=data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=None, |
| | | collate_fn=None, |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | use_collate_fn=False, |
| | | ) |
| | | |
| | | # 7. Start for-loop |
| | |
| | | return result_list |
| | | |
| | | return _forward |
| | | |
| | | |
| | | def inference_eend( |
| | | diar_train_config: str, |
| | |
| | | if isinstance(raw_inputs, torch.Tensor): |
| | | raw_inputs = raw_inputs.numpy() |
| | | data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"] |
| | | loader = EENDOLADiarTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | loader = build_streaming_iterator( |
| | | task_name="diar", |
| | | preprocess_args=None, |
| | | data_path_and_name_and_type=data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False), |
| | | collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | # 3. Start for-loop |
| | |
| | | return _forward |
| | | |
| | | |
| | | |
| | | |
| | | def inference_launch(mode, **kwargs): |
| | | if mode == "sond": |
| | | return inference_sond(mode=mode, **kwargs) |
| | |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="Speaker Verification", |