游雁
2023-07-05 4e2fe544ae37174a3e09dfcdbbdae5abfe711e53
funasr/models/e2e_diar_sond.py
@@ -7,12 +7,11 @@
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Tuple, List, Union
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -22,7 +21,9 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -33,9 +34,13 @@
        yield
class DiarSondModel(AbsESPnetModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
class DiarSondModel(FunASRModel):
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """
    def __init__(
@@ -44,19 +49,22 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        speaker_encoder: AbsEncoder,
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
        cd_scorer: torch.nn.Module,
        cd_scorer: Optional[torch.nn.Module],
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normlize_speech_speaker: bool = False,
        normalize_speech_speaker: bool = False,
        ignore_id: int = -1,
        speaker_discrimination_loss_weight: float = 1.0,
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
    ):
        assert check_argument_types()
        super().__init__()
@@ -71,7 +79,29 @@
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normlize_speech_speaker
        self.normalize_speech_speaker = normalize_speech_speaker
        self.ignore_id = ignore_id
        self.criterion_diar = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
        self.inputs_type = inputs_type
    def generate_pse_embedding(self):
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float)
        for idx, pse_label in enumerate(self.token_list):
            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=float)
            embedding[idx] = emb
        return torch.from_numpy(embedding)
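    # A worked example of the lookup table built above, assuming int2vec
    # expands an integer into its little-endian binary vector (consistent with
    # the 2 ** np.arange(max_spk_num) power weights in __init__): with
    # max_spk_num = 4, token "5" maps to [1, 0, 1, 0], i.e. speakers 0 and 2
    # are active.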
    def forward(
        self,
@@ -79,13 +109,12 @@
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
        Args:
            speech: (Batch, samples)
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,) default None for chunk iterator,
                                     because the chunk-iterator does not
                                     have the speech_lengths returned.
@@ -93,63 +122,53 @@
                                     espnet2/iterators/chunk_iter_factory.py
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            spk_labels: (Batch, )
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        self.forward_steps = self.forward_steps + 1
        if self.pse_embedding.device != speech.device:
            self.pse_embedding = self.pse_embedding.to(speech.device)
            self.power_weight = self.power_weight.to(speech.device)
            self.int_token_arr = self.int_token_arr.to(speech.device)
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # 1. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
            profile, profile_lengths,
            return_inter_outputs=True
        )
        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
        if self.attractor is None:
            # 2a. Decoder (basically a prediction layer after encoder_out)
            pred = self.decoder(encoder_out, encoder_out_lens)
        else:
            # 2b. Encoder Decoder Attractors
            # Shuffle the chronological order of encoder_out, then calculate attractor
            encoder_out_shuffled = encoder_out.clone()
            for i in range(len(encoder_out_lens)):
                encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
                    i, torch.randperm(encoder_out_lens[i]), :
                ]
            attractor, att_prob = self.attractor(
                encoder_out_shuffled,
                encoder_out_lens,
                to_device(
                    self,
                    torch.zeros(
                        encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
                    ),
                ),
            )
            # Remove the final attractor which does not correspond to a speaker
            # Then multiply the attractors and encoder_out
            pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
        # 3. Aggregate time-domain labels
        # 2. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            spk_labels, spk_labels_lengths = self.label_aggregator(
                spk_labels, spk_labels_lengths
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )
        # 2. Calculate power-set encoding (PSE) labels
        raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
        pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
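        # Worked example with max_spk_num = 3, power_weight = [1, 2, 4]:
        # a frame with binary_labels [1, 0, 1] gives 1*1 + 0*2 + 1*4 = 5, and
        # the PSE label becomes the index of token "5" in token_list (found by
        # comparison against int_token_arr).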
        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slighly less than the
        # the sequence length of 'pred' might be slightly less than the
        # length of 'spk_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = spk_labels.shape[1] - pred.shape[1]
        if length_diff > 0 and length_diff <= length_diff_tolerance:
            spk_labels = spk_labels[:, 0 : pred.shape[1], :]
        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
        if length_diff <= length_diff_tolerance:
            min_len = min(pred.shape[1], pse_labels.shape[1])
            pse_labels = pse_labels[:, :min_len]
            pred = pred[:, :min_len]
            cd_score = cd_score[:, :min_len]
            ci_score = ci_score[:, :min_len]
        if self.attractor is None:
            loss_pit, loss_att = None, None
            loss, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
        else:
            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
            loss_att = self.attractor_loss(att_prob, spk_labels)
            loss = loss_pit + self.attractor_weight * loss_att
        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
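        # For the DER-style metrics below, predicted PSE token ids and the
        # reference labels are mapped back to multi-hot speaker activity with
        # an embedding lookup into pse_embedding, so frame-level errors can be
        # counted per speaker.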
        (
            correct,
            num_frames,
@@ -160,7 +179,11 @@
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
            length=binary_labels_lengths
        )
        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
@@ -177,8 +200,10 @@
        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_pit=loss_pit.detach() if loss_pit is not None else None,
            loss_diar=loss_diar.detach() if loss_diar is not None else None,
            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
@@ -186,17 +211,78 @@
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def classification_loss(
            self,
            predictions: torch.Tensor,
            labels: torch.Tensor,
            prediction_lengths: torch.Tensor
    ) -> torch.Tensor:
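        # Padded frames are overwritten with ignore_id so the label-smoothing
        # cross entropy excludes them from the loss.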
        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
        pad_labels = labels.masked_fill(
            mask.to(predictions.device),
            value=self.ignore_id
        )
        loss = self.criterion_diar(predictions.contiguous(), pad_labels)
        return loss
    def speaker_discrimination_loss(
            self,
            profile: torch.Tensor,
            profile_lengths: torch.Tensor
    ) -> torch.Tensor:
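        # Valid profiles are the rows with non-zero L2 norm; the loss penalizes
        # positive cosine similarity between any two different valid profiles,
        # pushing enrolled speaker embeddings apart (the "- 0.0" below acts as
        # a zero margin).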
        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
        eps = 1e-12
        coding_norm = torch.linalg.norm(
            profile * profile_mask + (1 - profile_mask) * eps,
            dim=2, keepdim=True
        ) * profile_mask
        # profile: Batch, N, dim
        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
        return loss
    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
        mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
        padding_labels = pse_labels.masked_fill(
            mask.to(pse_labels.device),
            value=0
        ).to(pse_labels)
        multi_labels = F.embedding(padding_labels, self.pse_embedding)
        return multi_labels
    def internal_score_loss(
            self,
            cd_score: torch.Tensor,
            ci_score: torch.Tensor,
            pse_labels: torch.Tensor,
            pse_labels_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
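        # Deep supervision on the intermediate scorers: both the CI and CD
        # score maps are trained with binary cross entropy against the
        # multi-hot labels recovered from the PSE labels.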
        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
        return ci_loss, cd_loss
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
@@ -222,7 +308,7 @@
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
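        # Assumption: inputs_type == "raw" marks inputs that still need the
        # speech encoder, while other values indicate pre-encoded features.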
        if self.encoder is not None:
        if self.encoder is not None and self.inputs_type == "raw":
            speech, speech_lengths = self.encode(speech, speech_lengths)
            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
@@ -249,7 +335,7 @@
            speaker_encoder_outputs: torch.Tensor,
            seq_len: torch.Tensor = None,
            spk_len: torch.Tensor = None,
    ) -> torch.Tensor:
    ) -> Tuple[torch.Tensor, torch.Tensor]:
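        # Per the SOND paper (arXiv:2211.10243), the CI (context-independent)
        # scorer rates each frame against each profile in isolation, while the
        # CD (context-dependent) scorer models cross-speaker context; both
        # return (B, T, max_spk_num) similarity maps.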
        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
        d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
        if self.normalize_speech_speaker:
@@ -265,11 +351,11 @@
        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
        else:
            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
        simi = torch.cat([cd_simi, ci_simi], dim=2)
        return simi
        return ci_simi, cd_simi
    def post_net_forward(self, simi, seq_len):
        logits = self.decoder(simi, seq_len)[0]
@@ -282,23 +368,26 @@
            speech_lengths: torch.Tensor,
            profile: torch.Tensor,
            profile_lengths: torch.Tensor,
    ) -> torch.Tensor:
            return_inter_outputs: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        # speech encoding
        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
        # speaker encoding
        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
        # calculating similarity
        similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        similarity = torch.cat([cd_simi, ci_simi], dim=2)
        # post net forward
        logits = self.post_net_forward(similarity, speech_lengths)
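        # Intermediate outputs are exposed so that forward() can attach the
        # auxiliary losses (speaker discrimination on profiles, BCE on the
        # ci/cd score maps) without recomputing the pipeline.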
        if return_inter_outputs:
            return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
        return logits
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
@@ -318,7 +407,8 @@
            # 4. Forward encoder
            # feats: (Batch, Length, Dim)
            # -> encoder_out: (Batch, Length2, Dim)
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
            encoder_outputs = self.encoder(feats, feats_lengths)
            encoder_out, encoder_out_lens = encoder_outputs[:2]
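            # Some encoders return extras (e.g. attention weights) after
            # (out, lens); slicing [:2] keeps only the two outputs needed here.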
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -363,9 +453,7 @@
        (batch_size, max_len, num_output) = label.size()
        # mask the padding part
        mask = np.zeros((batch_size, max_len, num_output))
        for i in range(batch_size):
            mask[i, : length[i], :] = 1
        mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
        # pred and label have the shape (batch_size, max_len, num_output)
        label_np = label.data.cpu().numpy().astype(int)
@@ -399,4 +487,4 @@
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
        )