| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import logging |
| | | import random |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from itertools import permutations |
| | |
| | | import numpy as np |
| | | import torch |
| | | from torch.nn import functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import to_device |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.abs_profileaug import AbsProfileAug |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy |
| | | from funasr.utils.misc import int2vec |
| | | from funasr.utils.hinter import hint_once |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | |
| | | |
| | | class DiarSondModel(FunASRModel): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """Speaker overlap-aware neural diarization model |
| | | reference: https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | |
| | | def __init__( |
| | |
| | | vocab_size: int, |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | profileaug: Optional[AbsProfileAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: torch.nn.Module, |
| | | speaker_encoder: Optional[torch.nn.Module], |
| | |
| | | speaker_discrimination_loss_weight: float = 1.0, |
| | | inter_score_loss_weight: float = 0.0, |
| | | inputs_type: str = "raw", |
| | | model_regularizer_weight: float = 0.0, |
| | | freeze_encoder: bool = False, |
| | | onfly_shuffle_speaker: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | super().__init__() |
| | | |
| | |
| | | self.normalize = normalize |
| | | self.frontend = frontend |
| | | self.specaug = specaug |
| | | self.profileaug = profileaug |
| | | self.label_aggregator = label_aggregator |
| | | self.decoder = decoder |
| | | self.token_list = token_list |
| | | self.max_spk_num = max_spk_num |
| | | self.normalize_speech_speaker = normalize_speech_speaker |
| | | self.ignore_id = ignore_id |
| | | self.model_regularizer_weight = model_regularizer_weight |
| | | self.freeze_encoder = freeze_encoder |
| | | self.onfly_shuffle_speaker = onfly_shuffle_speaker |
| | | self.criterion_diar = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |
| | | self.inputs_type = inputs_type |
| | | self.to_regularize_parameters = None |
| | | |
def get_regularize_parameters(self):
    """Partition the model's named parameters for weight regularization.

    A parameter goes into the "regularize" group when its name contains both
    "encoder" and "weight", does not contain "bn", and contains at least one
    of "conv1", "conv2", "conv_sc" or "dense". Every other parameter is
    "normal". The regularize group is also cached on
    ``self.to_regularize_parameters``, which ``calculate_regularizer_loss``
    iterates.

    Returns:
        tuple(list, list): ``(to_regularize_parameters, normal_parameters)``,
        each a list of ``(name, param)`` pairs in ``named_parameters()`` order.
    """
    layer_tags = ("conv2", "conv1", "conv_sc", "dense")

    def is_regularized(param_name):
        # Plain substring tests: any name containing "encoder" qualifies,
        # including e.g. parameters under a "speaker_encoder" submodule.
        if "encoder" not in param_name or "weight" not in param_name:
            return False
        # Names containing "bn" (presumably batch-norm layers) are excluded.
        if "bn" in param_name:
            return False
        return any(tag in param_name for tag in layer_tags)

    selected, remaining = [], []
    for name, param in self.named_parameters():
        target = selected if is_regularized(name) else remaining
        target.append((name, param))
    self.to_regularize_parameters = selected
    return selected, remaining
| | | |
def generate_pse_embedding(self):
    """Build the power-set encoding (PSE) lookup table for ``self.token_list``.

    Row ``i`` of the table is
    ``int2vec(int(self.token_list[i]), vec_dim=self.max_spk_num)``. This is
    presumably the binary expansion of the PSE label over ``max_spk_num``
    speakers; see ``funasr.utils.misc.int2vec`` to confirm.

    Returns:
        torch.Tensor: float32 tensor of shape
        ``(len(self.token_list), self.max_spk_num)``.
    """
    # Use float32 only. The ``np.float`` alias used previously was removed in
    # NumPy 1.24 and raised AttributeError before the table was built.
    num_speakers = self.max_spk_num
    embedding = np.zeros((len(self.token_list), num_speakers), dtype=np.float32)
    for idx, pse_label in enumerate(self.token_list):
        embedding[idx] = int2vec(int(pse_label), vec_dim=num_speakers, dtype=np.float32)
    return torch.from_numpy(embedding)
| | | |
def rand_permute_speaker(self, raw_profile, raw_binary_labels):
    """Shuffle the speaker order, applying the same permutation to profiles and labels.

    Each batch item gets its own permutation drawn with ``random.shuffle``.
    That permutation is applied to the speaker axis of the profiles (dim 1)
    and of the labels (dim 2), so profile ``j`` still lines up with label
    column ``j`` afterwards. The inputs are cloned and never modified.

    Args:
        raw_profile: speaker profiles, B, N, D
        raw_binary_labels: per-frame speaker labels, B, T, N

    Returns:
        (profile, binary_labels): permuted copies of the inputs.
    """
    assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
        "Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[-1])
    shuffled_profile = raw_profile.clone()
    shuffled_labels = raw_binary_labels.clone()
    n_batch = shuffled_profile.shape[0]
    n_speakers = shuffled_profile.shape[1]
    for b in range(n_batch):
        order = [s for s in range(n_speakers)]
        random.shuffle(order)
        # Index row-by-row so the label item keeps its (T, N) layout.
        shuffled_profile[b] = shuffled_profile[b][order, :]
        shuffled_labels[b] = shuffled_labels[b][:, order]
    return shuffled_profile, shuffled_labels
| | | |
| | | def forward( |
| | | self, |
| | |
| | | binary_labels_lengths: torch.Tensor = None, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss |
| | | |
| | | Args: |
| | | speech: (Batch, samples) or (Batch, frames, input_size) |
| | | speech_lengths: (Batch,) default None for chunk interator, |
| | |
| | | """ |
| | | assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | if self.freeze_encoder: |
| | | hint_once("Freeze encoder", "freeze_encoder", rank=0) |
| | | self.encoder.eval() |
| | | self.forward_steps = self.forward_steps + 1 |
| | | if self.pse_embedding.device != speech.device: |
| | | self.pse_embedding = self.pse_embedding.to(speech.device) |
| | | self.power_weight = self.power_weight.to(speech.device) |
| | | self.int_token_arr = self.int_token_arr.to(speech.device) |
| | | |
| | | # 1. Network forward |
| | | if self.onfly_shuffle_speaker: |
| | | hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0) |
| | | profile, binary_labels = self.rand_permute_speaker(profile, binary_labels) |
| | | |
| | | # 0a. Aggregate time-domain labels to match forward outputs |
| | | if self.label_aggregator is not None: |
| | | binary_labels, binary_labels_lengths = self.label_aggregator( |
| | | binary_labels, binary_labels_lengths |
| | | ) |
| | | # 0b. augment profiles |
| | | if self.profileaug is not None and self.training: |
| | | speech, profile, binary_labels = self.profileaug( |
| | | speech, speech_lengths, |
| | | profile, profile_lengths, |
| | | binary_labels, binary_labels_lengths |
| | | ) |
| | | |
| | | # 1. Calculate power-set encoding (PSE) labels |
| | | pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0) |
| | | raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True) |
| | | pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2) |
| | | |
| | | # 2. Network forward |
| | | pred, inter_outputs = self.prediction_forward( |
| | | speech, speech_lengths, |
| | | profile, profile_lengths, |
| | | return_inter_outputs=True |
| | | ) |
| | | (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs |
| | | |
| | | # 2. Aggregate time-domain labels to match forward outputs |
| | | if self.label_aggregator is not None: |
| | | binary_labels, binary_labels_lengths = self.label_aggregator( |
| | | binary_labels, binary_labels_lengths |
| | | ) |
| | | # 2. Calculate power-set encoding (PSE) labels |
| | | raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True) |
| | | pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2) |
| | | |
| | | # If encoder uses conv* as input_layer (i.e., subsampling), |
| | | # the sequence length of 'pred' might be slightly less than the |
| | |
| | | loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) |
| | | loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) |
| | | loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths) |
| | | regularizer_loss = None |
| | | if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None: |
| | | regularizer_loss = self.calculate_regularizer_loss() |
| | | label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device) |
| | | loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis |
| | | + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd)) |
| | | # if regularizer_loss is not None: |
| | | # loss = loss + regularizer_loss * self.model_regularizer_weight |
| | | |
| | | ( |
| | | correct, |
| | |
| | | loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None, |
| | | loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None, |
| | | loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None, |
| | | regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None, |
| | | sad_mr=sad_mr, |
| | | sad_fr=sad_fr, |
| | | mi=mi, |
| | |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
def calculate_regularizer_loss(self):
    """Return the sum of the L2 norms of the cached regularization parameters.

    Iterates ``self.to_regularize_parameters`` (``(name, param)`` pairs, as
    stored by ``get_regularize_parameters``) and adds
    ``torch.norm(param, p=2)`` for each pair. An empty list yields the float
    ``0.0``.
    """
    norms = (torch.norm(tensor, p=2) for _, tensor in self.to_regularize_parameters)
    # Start from 0.0 and accumulate left to right, matching a manual running sum.
    return sum(norms, 0.0)
| | | |
| | | def classification_loss( |
| | | self, |
| | |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder |
| | | |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch,) |
| | |
| | | speaker_miss, |
| | | speaker_falarm, |
| | | speaker_error, |
| | | ) |
| | | ) |