| | |
| | | # -*- encoding: utf-8 -*- |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from collections import OrderedDict |
| | | from pathlib import Path |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | from collections import OrderedDict |
| | | import numpy as np |
| | | import soundfile |
| | | import torch |
| | | from scipy.ndimage import median_filter |
| | | from torch.nn import functional as F |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.tasks.diar import DiarTask |
| | | from funasr.tasks.diar import EENDOLADiarTask |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from scipy.ndimage import median_filter |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | from funasr.datasets.iterable_dataset import load_bytes |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.tasks.diar import DiarTask |
| | | from funasr.build_utils.build_model_from_file import build_model_from_file |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.utils.misc import statistic_model_parameters |
| | | |
| | | |
| | | class Speech2DiarizationEEND: |
| | | """Speech2Diarization class
| | |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build Diarization model |
| | | diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file( |
| | | diar_model, diar_train_args = build_model_from_file( |
| | | config_file=diar_train_config, |
| | | model_file=diar_model_file, |
| | | device=device |
| | | device=device, |
| | | task_name="diar", |
| | | mode="eend-ola", |
| | | ) |
| | | frontend = None |
| | | if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None: |
| | |
| | | assert check_argument_types() |
| | | |
| | | # TODO: 1. Build Diarization model |
| | | diar_model, diar_train_args = DiarTask.build_model_from_file( |
| | | diar_model, diar_train_args = build_model_from_file( |
| | | config_file=diar_train_config, |
| | | model_file=diar_model_file, |
| | | device=device |
| | | device=device, |
| | | task_name="diar", |
| | | mode="sond", |
| | | ) |
| | | logging.info("diar_model: {}".format(diar_model)) |
| | | logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model))) |
| | |
| | | ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio |
| | | logits_idx = F.upsample( |
| | | logits_idx.unsqueeze(1).float(), |
| | | size=(ut, ), |
| | | size=(ut,), |
| | | mode="nearest", |
| | | ).squeeze(1).long() |
| | | logits_idx = logits_idx[0].tolist() |
| | |
| | | if spk not in results: |
| | | results[spk] = [] |
| | | if dur > self.dur_threshold: |
| | | results[spk].append((st, st+dur)) |
| | | results[spk].append((st, st + dur)) |
| | | |
| | | # sort segments in start time ascending |
| | | for spk in results: |
| | |
| | | kwargs.update(**d.download_and_unpack(model_tag)) |
| | | |
| | | return Speech2DiarizationSOND(**kwargs) |
| | | |
| | | |
| | | |
| | | |