| | |
| | | from typing import Union |
| | | |
| | | import numpy as np |
| | | import soundfile |
| | | import librosa |
| | | import torch |
| | | from scipy.signal import medfilt |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.bin.diar_infer import Speech2DiarizationSOND, Speech2DiarizationEEND |
| | | from funasr.datasets.iterable_dataset import load_bytes |
| | |
| | | mode: str = "sond", |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | ncpu = kwargs.get("ncpu", 1) |
| | | torch.set_num_threads(ncpu) |
| | | if batch_size > 1: |
| | |
| | | embedding_node="resnet1_dense" |
| | | ) |
| | | logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs)) |
| | | speech2xvector = Speech2Xvector.from_pretrained( |
| | | model_tag=model_tag, |
| | | **speech2xvector_kwargs, |
| | | ) |
| | | speech2xvector = Speech2Xvector(**speech2xvector_kwargs) |
| | | speech2xvector.sv_model.eval() |
| | | |
| | | # 2b. Build speech2diar |
| | |
| | | dur_threshold=dur_threshold, |
| | | ) |
| | | logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs)) |
| | | speech2diar = Speech2DiarizationSOND.from_pretrained( |
| | | model_tag=model_tag, |
| | | **speech2diar_kwargs, |
| | | ) |
| | | speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs) |
| | | speech2diar.diar_model.eval() |
| | | |
| | | def output_results_str(results: dict, uttid: str): |
| | |
| | | # read waveform file |
| | | example = [load_bytes(x) if isinstance(x, bytes) else x |
| | | for x in example] |
| | | example = [soundfile.read(x)[0] if isinstance(x, str) else x |
| | | # example = [librosa.load(x)[0] if isinstance(x, str) else x |
| | | # for x in example] |
| | | example = [librosa.load(x, dtype='float32')[0] if isinstance(x, str) else x |
| | | for x in example] |
| | | # convert torch tensor to numpy array |
| | | example = [x.numpy() if isinstance(example[0], torch.Tensor) else x |
| | |
| | | param_dict: Optional[dict] = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | ncpu = kwargs.get("ncpu", 1) |
| | | torch.set_num_threads(ncpu) |
| | | if batch_size > 1: |
| | |
| | | dtype=dtype, |
| | | ) |
| | | logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs)) |
| | | speech2diar = Speech2DiarizationEEND.from_pretrained( |
| | | model_tag=model_tag, |
| | | **speech2diar_kwargs, |
| | | ) |
| | | speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs) |
| | | speech2diar.diar_model.eval() |
| | | |
| | | def output_results_str(results: dict, uttid: str): |
| | |
| | | help="The batch size for inference", |
| | | ) |
| | | group.add_argument( |
| | | "--diar_smooth_size", |
| | | "--smooth_size", |
| | | type=int, |
| | | default=121, |
| | | help="The smoothing size for post-processing" |
| | | ) |
| | | group.add_argument( |
| | | "--dur_threshold", |
| | | type=int, |
| | | default=10, |
| | | help="The threshold of minimum duration" |
| | | ) |
| | | |
| | | return parser |
| | | |