雾聪
2023-08-07 f8d1c79fe355efb18ae49e4363307dfec3ab89ce
funasr/models/e2e_diar_eend_ola.py
@@ -1,21 +1,20 @@
 # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from funasr.models.base_model import FunASRModel
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
 from funasr.modules.eend_ola.utils.power import create_powerlabel
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -33,12 +32,35 @@
     return att
 
 
+def pad_labels(ts, out_size):
+    for i, t in enumerate(ts):
+        if t.shape[1] < out_size:
+            ts[i] = F.pad(
+                t,
+                (0, out_size - t.shape[1], 0, 0),
+                mode='constant',
+                value=0.
+            )
+    return ts
+
+
+def pad_results(ys, out_size):
+    ys_padded = []
+    for i, y in enumerate(ys):
+        if y.shape[1] < out_size:
+            ys_padded.append(
+                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+        else:
+            ys_padded.append(y)
+    return ys_padded
+
+
 class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            frontend: WavFrontendMel23,
+            frontend: Optional[WavFrontendMel23],
             encoder: EENDOLATransformerEncoder,
             encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
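For orientation, a minimal usage sketch of the two padding helpers added above (shapes inferred from the diff; the import path assumes this module's location in the repo):

import torch
from funasr.models.e2e_diar_eend_ola import pad_labels, pad_results

# Two utterances of 5 frames, with 2 and 3 active speakers respectively.
labels = [torch.ones(5, 2), torch.ones(5, 3)]
labels = pad_labels(labels, out_size=4)        # zero-pads the speaker axis via F.pad
print([t.shape for t in labels])               # [torch.Size([5, 4]), torch.Size([5, 4])]

logits = [torch.randn(5, 2), torch.randn(5, 4)]
logits = pad_results(logits, out_size=4)       # pads by concatenating zero columns
print([y.shape for y in logits])               # both torch.Size([5, 4])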
@@ -47,11 +69,10 @@
             mapping_dict=None,
             **kwargs,
     ):
         super().__init__()
         self.frontend = frontend
         self.enc = encoder
-        self.eda = encoder_decoder_attractor
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
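Note: `mapping_dict` (built by `generate_mapping_dict` when not supplied) backs the power-set ("PSE") labels used in `forward` below, where each per-frame multi-speaker binary vector collapses to a single class index. A toy sketch of that encoding idea, not the repo's exact `create_powerlabel` logic, which also handles speaker ordering and the `max_n_speaker` cap:

import torch

def to_power_label(labels: torch.Tensor) -> torch.Tensor:
    # labels: (frames, n_speakers) 0/1 matrix -> (frames,) class indices,
    # reading each row as a binary number (hypothetical helper for illustration).
    weights = 2 ** torch.arange(labels.shape[1])
    return (labels.long() * weights).sum(dim=-1)

frame_labels = torch.tensor([[0, 0], [1, 0], [1, 1]])  # silence, spk0 only, overlap
print(to_power_label(frame_labels))                    # tensor([0, 1, 3])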
@@ -74,7 +95,8 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+                                                   enforce_sorted=False)
         outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
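`forward_post_net` follows the standard pad → pack → LSTM → unpack round-trip for variable-length batches. A self-contained sketch of that pattern with toy sizes, a plain `nn.LSTM` standing in for `self.postnet`:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)  # stand-in post-net
seqs = [torch.randn(7, 8), torch.randn(4, 8)]                   # variable-length inputs
ilens = torch.tensor([7, 4])
maxlen = int(ilens.max())

x = nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=-1)
x = nn.utils.rnn.pack_padded_sequence(x, ilens.cpu().to(torch.int64), batch_first=True,
                                      enforce_sorted=False)
out, (_, _) = lstm(x)
out = nn.utils.rnn.pad_packed_sequence(out, batch_first=True, padding_value=-1,
                                       total_length=maxlen)[0]
out = [o[:ilens[i]] for i, o in enumerate(out)]  # trim each sequence to its true length
print([o.shape for o in out])                    # [torch.Size([7, 16]), torch.Size([4, 16])]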
@@ -83,95 +105,45 @@
     def forward(
             self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+            speech: List[torch.Tensor],
+            speaker_labels: List[torch.Tensor],
+            orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
-            text: (Batch, Length)
-            text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
-        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
-
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
-
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
-
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
+        assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
+        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
+        speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
+        batch_size = len(speech)
+
+        # Encoder
+        encoder_out = self.forward_encoder(speech, speech_lengths)
+
+        # Encoder-decoder attractor
+        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+                                                                    speaker_labels_lengths)
+        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
+
+        # pit loss
+        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+        # pse loss
+        with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+                            to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
 
         stats = dict()
-
-        # 1. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc
-
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-                # Collect Intermediate CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-            loss_interctc = loss_interctc / len(intermediate_outs)
-
-            # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-
-        # 2b. Attention decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-
-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
+        stats["pse_loss"] = pse_loss.detach()
+        stats["pit_loss"] = pit_loss.detach()
+        stats["attractor_loss"] = attractor_loss.detach()
+        stats["batch_size"] = batch_size
 
         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
@@ -182,21 +154,20 @@
     def estimate_sequential(self,
                             speech: torch.Tensor,
                             speech_lengths: torch.Tensor,
                             n_speakers: int = None,
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
         speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
         speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.eda.estimate(
+            attractors, probs = self.encoder_decoder_attractor.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0: