游雁
2023-09-13 33d3d2084403fd34b79c835d2f2fe04f6cd8f738
funasr/models/e2e_diar_sond.py
@@ -1,18 +1,18 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
import random
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Tuple, List
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -20,9 +20,13 @@
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
from funasr.utils.hinter import hint_once
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -33,7 +37,7 @@
        yield
class DiarSondModel(AbsESPnetModel):
class DiarSondModel(FunASRModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
    """
@@ -43,20 +47,27 @@
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        profileaug: Optional[AbsProfileAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        speaker_encoder: AbsEncoder,
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
        cd_scorer: torch.nn.Module,
        cd_scorer: Optional[torch.nn.Module],
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normlize_speech_speaker: bool = False,
        normalize_speech_speaker: bool = False,
        ignore_id: int = -1,
        speaker_discrimination_loss_weight: float = 1.0,
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
        model_regularizer_weight: float = 0.0,
        freeze_encoder: bool = False,
        onfly_shuffle_speaker: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
@@ -67,11 +78,68 @@
        self.normalize = normalize
        self.frontend = frontend
        self.specaug = specaug
        self.profileaug = profileaug
        self.label_aggregator = label_aggregator
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normlize_speech_speaker
        self.normalize_speech_speaker = normalize_speech_speaker
        self.ignore_id = ignore_id
        self.model_regularizer_weight = model_regularizer_weight
        self.freeze_encoder = freeze_encoder
        self.onfly_shuffle_speaker = onfly_shuffle_speaker
        self.criterion_diar = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
        self.inputs_type = inputs_type
        self.to_regularize_parameters = None
    def get_regularize_parameters(self):
        to_regularize_parameters, normal_parameters = [], []
        for name, param in self.named_parameters():
            if ("encoder" in name and "weight" in name and "bn" not in name and
                ("conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name)
            ):
                to_regularize_parameters.append((name, param))
            else:
                normal_parameters.append((name, param))
        self.to_regularize_parameters = to_regularize_parameters
        return to_regularize_parameters, normal_parameters
    def generate_pse_embedding(self):
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
        for idx, pse_label in enumerate(self.token_list):
            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
            embedding[idx] = emb
        return torch.from_numpy(embedding)
    def rand_permute_speaker(self, raw_profile, raw_binary_labels):
        """
        raw_profile: B, N, D
        raw_binary_labels: B, T, N
        """
        assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
            "Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[-1])
        profile = torch.clone(raw_profile)
        binary_labels = torch.clone(raw_binary_labels)
        bsz, num_spk = profile.shape[0], profile.shape[1]
        for i in range(bsz):
            idx = list(range(num_spk))
            random.shuffle(idx)
            profile[i] = profile[i][idx, :]
            binary_labels[i] = binary_labels[i][:, idx]
        return profile, binary_labels
    def forward(
        self,
@@ -79,13 +147,13 @@
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
        Args:
            speech: (Batch, samples)
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,) default None for chunk iterator,
                                     because the chunk-iterator does not
                                     have the speech_lengths returned.
@@ -93,63 +161,74 @@
                                     espnet2/iterators/chunk_iter_factory.py
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            spk_labels: (Batch, )
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
        assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        if self.freeze_encoder:
            hint_once("Freeze encoder", "freeze_encoder", rank=0)
            self.encoder.eval()
        self.forward_steps = self.forward_steps + 1
        if self.pse_embedding.device != speech.device:
            self.pse_embedding = self.pse_embedding.to(speech.device)
            self.power_weight = self.power_weight.to(speech.device)
            self.int_token_arr = self.int_token_arr.to(speech.device)
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if self.onfly_shuffle_speaker:
            hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0)
            profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)
        if self.attractor is None:
            # 2a. Decoder (basically a prediction layer after encoder_out)
            pred = self.decoder(encoder_out, encoder_out_lens)
        else:
            # 2b. Encoder Decoder Attractors
            # Shuffle the chronological order of encoder_out, then calculate attractor
            encoder_out_shuffled = encoder_out.clone()
            for i in range(len(encoder_out_lens)):
                encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
                    i, torch.randperm(encoder_out_lens[i]), :
                ]
            attractor, att_prob = self.attractor(
                encoder_out_shuffled,
                encoder_out_lens,
                to_device(
                    self,
                    torch.zeros(
                        encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
                    ),
                ),
            )
            # Remove the final attractor which does not correspond to a speaker
            # Then multiply the attractors and encoder_out
            pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
        # 3. Aggregate time-domain labels
        # 0a. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            spk_labels, spk_labels_lengths = self.label_aggregator(
                spk_labels, spk_labels_lengths
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )
        # 0b. augment profiles
        if self.profileaug is not None and self.training:
            speech, profile, binary_labels = self.profileaug(
                speech, speech_lengths,
                profile, profile_lengths,
                binary_labels, binary_labels_lengths
            )
        # 1. Calculate power-set encoding (PSE) labels
        pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0)
        raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True)
        pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
        # 2. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
            profile, profile_lengths,
            return_inter_outputs=True
        )
        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slighly less than the
        # the sequence length of 'pred' might be slightly less than the
        # length of 'spk_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = spk_labels.shape[1] - pred.shape[1]
        if length_diff > 0 and length_diff <= length_diff_tolerance:
            spk_labels = spk_labels[:, 0 : pred.shape[1], :]
        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
        if length_diff <= length_diff_tolerance:
            min_len = min(pred.shape[1], pse_labels.shape[1])
            pse_labels = pse_labels[:, :min_len]
            pred = pred[:, :min_len]
            cd_score = cd_score[:, :min_len]
            ci_score = ci_score[:, :min_len]
        if self.attractor is None:
            loss_pit, loss_att = None, None
            loss, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
        else:
            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
                pred, spk_labels, encoder_out_lens
            )
            loss_att = self.attractor_loss(att_prob, spk_labels)
            loss = loss_pit + self.attractor_weight * loss_att
        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
        regularizer_loss = None
        if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None:
            regularizer_loss = self.calculate_regularizer_loss()
        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
        # if regularizer_loss is not None:
        #     loss = loss + regularizer_loss * self.model_regularizer_weight
        (
            correct,
            num_frames,
@@ -160,7 +239,11 @@
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
            length=binary_labels_lengths
        )
        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
@@ -177,8 +260,11 @@
        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_pit=loss_pit.detach() if loss_pit is not None else None,
            loss_diar=loss_diar.detach() if loss_diar is not None else None,
            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
            regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None,
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
@@ -186,17 +272,84 @@
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def calculate_regularizer_loss(self):
        regularizer_loss = 0.0
        for name, param in self.to_regularize_parameters:
            regularizer_loss = regularizer_loss + torch.norm(param, p=2)
        return regularizer_loss
    def classification_loss(
            self,
            predictions: torch.Tensor,
            labels: torch.Tensor,
            prediction_lengths: torch.Tensor
    ) -> torch.Tensor:
        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
        pad_labels = labels.masked_fill(
            mask.to(predictions.device),
            value=self.ignore_id
        )
        loss = self.criterion_diar(predictions.contiguous(), pad_labels)
        return loss
    def speaker_discrimination_loss(
            self,
            profile: torch.Tensor,
            profile_lengths: torch.Tensor
    ) -> torch.Tensor:
        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
        eps = 1e-12
        coding_norm = torch.linalg.norm(
            profile * profile_mask + (1 - profile_mask) * eps,
            dim=2, keepdim=True
        ) * profile_mask
        # profile: Batch, N, dim
        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
        return loss
    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
        mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
        padding_labels = pse_labels.masked_fill(
            mask.to(pse_labels.device),
            value=0
        ).to(pse_labels)
        multi_labels = F.embedding(padding_labels, self.pse_embedding)
        return multi_labels
    def internal_score_loss(
            self,
            cd_score: torch.Tensor,
            ci_score: torch.Tensor,
            pse_labels: torch.Tensor,
            pse_labels_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
        return ci_loss, cd_loss
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
@@ -222,7 +375,7 @@
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.encoder is not None:
        if self.encoder is not None and self.inputs_type == "raw":
            speech, speech_lengths = self.encode(speech, speech_lengths)
            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
@@ -249,7 +402,7 @@
            speaker_encoder_outputs: torch.Tensor,
            seq_len: torch.Tensor = None,
            spk_len: torch.Tensor = None,
    ) -> torch.Tensor:
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
        d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2]
        if self.normalize_speech_speaker:
@@ -265,11 +418,11 @@
        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
        else:
            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
        simi = torch.cat([cd_simi, ci_simi], dim=2)
        return simi
        return ci_simi, cd_simi
    def post_net_forward(self, simi, seq_len):
        logits = self.decoder(simi, seq_len)[0]
@@ -282,16 +435,20 @@
            speech_lengths: torch.Tensor,
            profile: torch.Tensor,
            profile_lengths: torch.Tensor,
    ) -> torch.Tensor:
            return_inter_outputs: bool = False,
    ) -> [torch.Tensor, Optional[list]]:
        # speech encoding
        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
        # speaker encoding
        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
        # calculating similarity
        similarity = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        ci_simi, cd_simi = self.calc_similarity(speech, profile, speech_lengths, profile_lengths)
        similarity = torch.cat([cd_simi, ci_simi], dim=2)
        # post net forward
        logits = self.post_net_forward(similarity, speech_lengths)
        if return_inter_outputs:
            return logits, [(speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi)]
        return logits
    def encode(
@@ -318,7 +475,8 @@
            # 4. Forward encoder
            # feats: (Batch, Length, Dim)
            # -> encoder_out: (Batch, Length2, Dim)
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
            encoder_outputs = self.encoder(feats, feats_lengths)
            encoder_out, encoder_out_lens = encoder_outputs[:2]
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -363,9 +521,7 @@
        (batch_size, max_len, num_output) = label.size()
        # mask the padding part
        mask = np.zeros((batch_size, max_len, num_output))
        for i in range(batch_size):
            mask[i, : length[i], :] = 1
        mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
        # pred and label have the shape (batch_size, max_len, num_output)
        label_np = label.data.cpu().numpy().astype(int)