| | |
| | | import torch.nn as nn |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.eend_ola.encoder import TransformerEncoder |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.modules.eend_ola.utils.power import generate_mapping_dict |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | |
| | | |
| | | |
| | | class DiarEENDOLAModel(AbsESPnetModel): |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | """EEND-OLA diarization model""" |
| | | |
| | | def __init__( |
| | | self, |
| | | encoder: TransformerEncoder, |
| | | eda: EncoderDecoderAttractor, |
| | | frontend: WavFrontendMel23, |
| | | encoder: EENDOLATransformerEncoder, |
| | | encoder_decoder_attractor: EncoderDecoderAttractor, |
| | | n_units: int = 256, |
| | | max_n_speaker: int = 8, |
| | | attractor_loss_weight: float = 1.0, |
| | |
| | | assert check_argument_types() |
| | | |
| | | super().__init__() |
| | | self.frontend = frontend |
| | | self.encoder = encoder |
| | | self.eda = eda |
| | | self.encoder_decoder_attractor = encoder_decoder_attractor |
| | | self.attractor_loss_weight = attractor_loss_weight |
| | | self.max_n_speaker = max_n_speaker |
| | | if mapping_dict is None: |
| | |
| | | shuffle: bool = True, |
| | | threshold: float = 0.5, |
| | | **kwargs): |
| | | if self.frontend is not None: |
| | | speech = self.frontend(speech) |
| | | speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] |
| | | emb = self.forward_encoder(speech, speech_lengths) |
| | | if shuffle: |
| | | orders = [np.arange(e.shape[0]) for e in emb] |
| | | for order in orders: |
| | | np.random.shuffle(order) |
| | | attractors, probs = self.eda.estimate( |
| | | attractors, probs = self.encoder_decoder_attractor.estimate( |
| | | [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) |
| | | else: |
| | | attractors, probs = self.eda.estimate(emb) |
| | | attractors, probs = self.encoder_decoder_attractor.estimate(emb) |
| | | attractors_active = [] |
| | | for p, att, e in zip(probs, attractors, emb): |
| | | if n_speakers and n_speakers >= 0: |