| | |
| | | import torch |
| | | from scipy.ndimage import median_filter |
| | | from torch.nn import functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.tasks.diar import DiarTask |
| | |
| | | device: str = "cpu", |
| | | dtype: str = "float32", |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build Diarization model |
| | | diar_model, diar_train_args = build_model_from_file( |
| | |
| | | diarization results |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | |
| | | results = self.diar_model.estimate_sequential(**batch) |
| | | |
| | | return results |
| | | |
@staticmethod
def from_pretrained(
    model_tag: Optional[str] = None,
    **kwargs: Optional[Any],
):
    """Build a Speech2DiarizationEEND instance, optionally from a pretrained model.

    Args:
        model_tag (Optional[str]): Model tag of the pretrained model.
            Currently, the tags of espnet_model_zoo are supported. When None,
            nothing is downloaded and ``kwargs`` are passed through unchanged.
        **kwargs: Keyword arguments forwarded to the Speech2DiarizationEEND
            constructor. When ``model_tag`` is given, the mapping returned by
            ``ModelDownloader.download_and_unpack(model_tag)`` is merged into
            them and overrides any same-named keys supplied by the caller.

    Returns:
        Speech2DiarizationEEND: Speech2DiarizationEEND instance.

    Raises:
        ImportError: If ``model_tag`` is given but ``espnet_model_zoo`` cannot
            be imported (an explanatory message is logged first).

    """
    if model_tag is not None:
        try:
            # Imported lazily so espnet_model_zoo is only required when a
            # pretrained tag is actually requested.
            from espnet_model_zoo.downloader import ModelDownloader
        except ImportError:
            logging.error(
                "`espnet_model_zoo` is not installed. "
                "Please install via `pip install -U espnet_model_zoo`."
            )
            raise
        downloader = ModelDownloader()
        # Merge the downloaded entries (presumably config/model file paths —
        # confirm against espnet_model_zoo) into the constructor kwargs;
        # downloaded values win over caller-supplied ones.
        kwargs.update(downloader.download_and_unpack(model_tag))

    return Speech2DiarizationEEND(**kwargs)
| | | |
| | | |
| | | class Speech2DiarizationSOND: |
| | |
| | | smooth_size: int = 83, |
| | | dur_threshold: float = 10, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | # TODO: 1. Build Diarization model |
| | | diar_model, diar_train_args = build_model_from_file( |
| | |
| | | diarization results for each speaker |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | |
| | | results, pse_labels = self.post_processing(logits, profile.shape[1], output_format) |
| | | |
| | | return results, pse_labels |
| | | |
@staticmethod
def from_pretrained(
    model_tag: Optional[str] = None,
    **kwargs: Optional[Any],
):
    """Build a Speech2DiarizationSOND instance, optionally from a pretrained model.

    Args:
        model_tag (Optional[str]): Model tag of the pretrained model.
            Currently, the tags of espnet_model_zoo are supported. When None,
            nothing is downloaded and ``kwargs`` are passed through unchanged.
        **kwargs: Keyword arguments forwarded to the Speech2DiarizationSOND
            constructor. When ``model_tag`` is given, the mapping returned by
            ``ModelDownloader.download_and_unpack(model_tag)`` is merged into
            them and overrides any same-named keys supplied by the caller.

    Returns:
        Speech2DiarizationSOND: Speech2DiarizationSOND instance.

    Raises:
        ImportError: If ``model_tag`` is given but ``espnet_model_zoo`` cannot
            be imported (an explanatory message is logged first).

    """
    if model_tag is not None:
        try:
            # Imported lazily so espnet_model_zoo is only required when a
            # pretrained tag is actually requested.
            from espnet_model_zoo.downloader import ModelDownloader
        except ImportError:
            logging.error(
                "`espnet_model_zoo` is not installed. "
                "Please install via `pip install -U espnet_model_zoo`."
            )
            raise
        downloader = ModelDownloader()
        # Merge the downloaded entries (presumably config/model file paths —
        # confirm against espnet_model_zoo) into the constructor kwargs;
        # downloaded values win over caller-supplied ones.
        kwargs.update(downloader.download_and_unpack(model_tag))

    return Speech2DiarizationSOND(**kwargs)