speech_asr
2023-03-15 26b81480a88cc2868639c5160989394199acdcdd
funasr/models/e2e_diar_eend_ola.py
@@ -52,15 +52,15 @@
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.enc = encoder
        self.eda = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
            self.mapping_dict = mapping_dict
        # PostNet
        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
    def forward_encoder(self, xs, ilens):
@@ -68,7 +68,7 @@
        pad_shape = xs.shape
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
        emb = self.encoder(xs, xs_mask)
        emb = self.enc(xs, xs_mask)
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb
@@ -77,7 +77,7 @@
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
        outputs, (_, _) = self.PostNet(logits)
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
        outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
        outputs = [self.output_layer(output) for output in outputs]
@@ -112,7 +112,7 @@
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
@@ -198,10 +198,10 @@
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
            attractors, probs = self.eda.estimate(
                [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
            attractors, probs = self.eda.estimate(emb)
        attractors_active = []
        for p, att, e in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0: