speech_asr
2023-03-13 e27de5aa6bd9af2a82e80604978b50aa538493ec
funasr/models/e2e_diar_eend_ola.py
@@ -11,7 +11,8 @@
import torch.nn as  nn
from typeguard import check_argument_types
from funasr.modules.eend_ola.encoder import TransformerEncoder
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
@@ -34,12 +35,13 @@
class DiarEENDOLAModel(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    """EEND-OLA diarization model"""
    def __init__(
            self,
            encoder: TransformerEncoder,
            eda: EncoderDecoderAttractor,
            frontend: WavFrontendMel23,
            encoder: EENDOLATransformerEncoder,
            encoder_decoder_attractor: EncoderDecoderAttractor,
            n_units: int = 256,
            max_n_speaker: int = 8,
            attractor_loss_weight: float = 1.0,
@@ -49,8 +51,9 @@
        assert check_argument_types()
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.eda = eda
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
@@ -187,16 +190,18 @@
                            shuffle: bool = True,
                            threshold: float = 0.5,
                            **kwargs):
        if self.frontend is not None:
            speech = self.frontend(speech)
        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.eda.estimate(
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.eda.estimate(emb)
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0: