From 5da92c1fa931a0607d880f7d6485d7ff53d928ec Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Wed, 15 Feb 2023 11:51:27 +0800
Subject: [PATCH] add training related code for sond
---
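Notes:

This patch replaces the EEND-style attractor training path in
e2e_diar_sond.py with SOND training: a label-smoothed cross-entropy over
power-set-encoded (PSE) speaker labels, a speaker-discrimination loss over
the profile embeddings, and optional binary cross-entropy supervision of
the intermediate CI/CD similarity scores. The new loss code also needs
torch.nn.functional and make_pad_mask; the make_pad_mask import path used
below (funasr.modules.nets_utils) is assumed from the usual FunASR layout.

A PSE label packs the set of active speakers into a single integer, and
generate_pse_embedding expands each label into a multi-hot vector via
int2vec. A minimal sketch of the assumed decoding (int2vec's exact bit
order may differ):

    import numpy as np

    def int2vec_sketch(label: int, vec_dim: int, dtype=np.float32) -> np.ndarray:
        # Bit i of the label marks speaker i as active in the frame.
        return np.array([(label >> i) & 1 for i in range(vec_dim)], dtype=dtype)

    # Example: label 5 == 0b0101 -> speakers 0 and 2 are active:
    # int2vec_sketch(5, 4) -> array([1., 0., 1., 0.], dtype=float32)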
funasr/models/e2e_diar_sond.py | 168 +++++++++++++++++++++++++++++++++++++++-----------------
1 file changed, 117 insertions(+), 51 deletions(-)
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index d29ffe5..7b6e955 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -7,7 +7,7 @@
 from itertools import permutations
 from typing import Dict
 from typing import Optional
-from typing import Tuple
+from typing import Tuple, List, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
@@ -23,6 +23,8 @@
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
+from funasr.utils.misc import int2vec
+from funasr.modules.nets_utils import make_pad_mask
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -54,7 +56,10 @@
         length_normalized_loss: bool = False,
         max_spk_num: int = 16,
         label_aggregator: Optional[torch.nn.Module] = None,
-        normlize_speech_speaker: bool = False,
+        normalize_speech_speaker: bool = False,
+        ignore_id: int = -1,
+        speaker_discrimination_loss_weight: float = 1.0,
+        inter_score_loss_weight: float = 0.0,
     ):
         assert check_argument_types()
@@ -71,7 +76,25 @@
         self.decoder = decoder
         self.token_list = token_list
         self.max_spk_num = max_spk_num
-        self.normalize_speech_speaker = normlize_speech_speaker
+        self.normalize_speech_speaker = normalize_speech_speaker
+        self.ignore_id = ignore_id
+        self.criterion_diar = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
+        self.pse_embedding = self.generate_pse_embedding()
+        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
+        self.inter_score_loss_weight = inter_score_loss_weight
+
+    def generate_pse_embedding(self):
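+        """Build a (vocab_size, max_spk_num) lookup table mapping each
+        power-set-encoded (PSE) label to its multi-hot speaker-activity
+        vector via int2vec, e.g. a label with binary code 0101 marks two
+        of the max_spk_num speaker slots as active.
+        """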
+        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
+        for idx, pse_label in enumerate(self.token_list):
+            emb = int2vec(pse_label, vec_dim=self.max_spk_num, dtype=np.float32)
+            embedding[idx] = emb
+        return torch.from_numpy(embedding)
     def forward(
         self,
@@ -85,7 +108,7 @@
         """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
 
         Args:
-            speech: (Batch, samples)
+            speech: (Batch, samples) or (Batch, frames, input_size)
             speech_lengths: (Batch,) default None for chunk interator,
                             because the chunk-iterator does not
                             have the speech_lengths returned.
@@ -93,63 +116,42 @@
                             espnet2/iterators/chunk_iter_factory.py
             profile: (Batch, N_spk, dim)
             profile_lengths: (Batch,)
-            spk_labels: (Batch, )
+            spk_labels: (Batch, frames) power-set encoded speaker labels
+            spk_labels_lengths: (Batch,)
         """
         assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
         batch_size = speech.shape[0]
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        # 1. Network forward
+        pred, inter_outputs = self.prediction_forward(
+            speech, speech_lengths,
+            profile, profile_lengths,
+            return_inter_outputs=True
+        )
+        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
-        if self.attractor is None:
-            # 2a. Decoder (baiscally a predction layer after encoder_out)
-            pred = self.decoder(encoder_out, encoder_out_lens)
-        else:
-            # 2b. Encoder Decoder Attractors
-            # Shuffle the chronological order of encoder_out, then calculate attractor
-            encoder_out_shuffled = encoder_out.clone()
-            for i in range(len(encoder_out_lens)):
-                encoder_out_shuffled[i, : encoder_out_lens[i], :] = encoder_out[
-                    i, torch.randperm(encoder_out_lens[i]), :
-                ]
-            attractor, att_prob = self.attractor(
-                encoder_out_shuffled,
-                encoder_out_lens,
-                to_device(
-                    self,
-                    torch.zeros(
-                        encoder_out.size(0), spk_labels.size(2) + 1, encoder_out.size(2)
-                    ),
-                ),
-            )
-            # Remove the final attractor which does not correspond to a speaker
-            # Then multiply the attractors and encoder_out
-            pred = torch.bmm(encoder_out, attractor[:, :-1, :].permute(0, 2, 1))
-        # 3. Aggregate time-domain labels
+        # 2. Aggregate time-domain labels to match forward outputs
         if self.label_aggregator is not None:
             spk_labels, spk_labels_lengths = self.label_aggregator(
-                spk_labels, spk_labels_lengths
+                spk_labels.unsqueeze(2), spk_labels_lengths
             )
+            spk_labels = spk_labels.squeeze(2)
         # If encoder uses conv* as input_layer (i.e., subsampling),
-        # the sequence length of 'pred' might be slighly less than the
+        # the sequence length of 'pred' might be slightly less than the
         # length of 'spk_labels'. Here we force them to be equal.
         length_diff_tolerance = 2
         length_diff = spk_labels.shape[1] - pred.shape[1]
-        if length_diff > 0 and length_diff <= length_diff_tolerance:
-            spk_labels = spk_labels[:, 0 : pred.shape[1], :]
+        if 0 < length_diff <= length_diff_tolerance:
+            spk_labels = spk_labels[:, :pred.shape[1]]
-        if self.attractor is None:
-            loss_pit, loss_att = None, None
-            loss, perm_idx, perm_list, label_perm = self.pit_loss(
-                pred, spk_labels, encoder_out_lens
-            )
-        else:
-            loss_pit, perm_idx, perm_list, label_perm = self.pit_loss(
-                pred, spk_labels, encoder_out_lens
-            )
-            loss_att = self.attractor_loss(att_prob, spk_labels)
-            loss = loss_pit + self.attractor_weight * loss_att
+        loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths)
+        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
+        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths)
+        label_mask = (~make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1])).to(spk_labels.device)
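+        # Total loss: diarization cross-entropy plus weighted speaker-
+        # discrimination and (optionally) weighted internal CI/CD score losses.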
+        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
+
         (
             correct,
             num_frames,
@@ -160,7 +162,11 @@
             speaker_miss,
             speaker_falarm,
             speaker_error,
-        ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens)
+        ) = self.calc_diarization_error(
+            pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding.to(pred.device)),
+            label=F.embedding(spk_labels * label_mask, self.pse_embedding.to(pred.device)),
+            length=spk_labels_lengths
+        )
         if speech_scored > 0 and num_frames > 0:
             sad_mr, sad_fr, mi, fa, cf, acc, der = (
@@ -177,8 +183,10 @@
         stats = dict(
             loss=loss.detach(),
-            loss_att=loss_att.detach() if loss_att is not None else None,
-            loss_pit=loss_pit.detach() if loss_pit is not None else None,
+            loss_diar=loss_diar.detach() if loss_diar is not None else None,
+            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
+            loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
+            loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
             sad_mr=sad_mr,
             sad_fr=sad_fr,
             mi=mi,
@@ -190,6 +198,61 @@
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
+
+    def classification_loss(
+        self,
+        predictions: torch.Tensor,
+        labels: torch.Tensor,
+        prediction_lengths: torch.Tensor
+    ) -> torch.Tensor:
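+        # Label-smoothed cross-entropy between frame-level logits
+        # (B, T, vocab_size) and PSE label indices (B, T); padded frames
+        # are set to ignore_id so they do not contribute to the loss.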
+        pad_labels = labels.masked_fill(
+            make_pad_mask(prediction_lengths, maxlen=labels.shape[1]).to(labels.device),
+            value=self.ignore_id
+        )
+        loss = self.criterion_diar(predictions, pad_labels)
+
+        return loss
+
+    def speaker_discrimination_loss(
+        self,
+        profile: torch.Tensor,
+        profile_lengths: torch.Tensor
+    ) -> torch.Tensor:
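+        # Penalize similarity between different valid speaker profiles:
+        # all-zero profile rows are masked out, the diagonal
+        # (self-similarity) is dropped, and positive pairwise cosine
+        # similarities are pushed towards zero.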
+        profile_mask = (torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0).float()  # (B, N, 1)
+        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
+        mask = mask * (1.0 - torch.eye(self.max_spk_num, device=profile.device).unsqueeze(0))
+
+        eps = 1e-12
+        coding_norm = torch.linalg.norm(
+            profile * profile_mask + (1 - profile_mask) * eps,
+            dim=2, keepdim=True
+        ) * profile_mask
+        cos_theta = F.cosine_similarity(profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps) * mask
+        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
+        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
+
+        return loss
+
+    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
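+        # Expand PSE label indices (B, T) into multi-hot speaker
+        # activities (B, T, max_spk_num) via the PSE embedding table.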
+        padding_labels = pse_labels.masked_fill(
+            make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device),
+            value=0
+        ).long()
+        multi_labels = F.embedding(padding_labels, self.pse_embedding.to(pse_labels.device))
+
+        return multi_labels
+
+    def internal_score_loss(
+        self,
+        cd_score: torch.Tensor,
+        ci_score: torch.Tensor,
+        pse_labels: torch.Tensor,
+        pse_labels_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
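+        # Supervise the intermediate context-independent (CI) and
+        # context-dependent (CD) similarity scores directly with binary
+        # cross-entropy against the multi-hot speaker-activity labels.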
+        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
+        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
+        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
+        return ci_loss, cd_loss
     def collect_feats(
         self,
@@ -282,7 +345,8 @@
         speech_lengths: torch.Tensor,
         profile: torch.Tensor,
         profile_lengths: torch.Tensor,
-    ) -> torch.Tensor:
+        return_inter_outputs: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         # speech encoding
         speech, speech_lengths = self.encode_speech(speech, speech_lengths)
         # speaker encoding
@@ -292,6 +356,8 @@
         # post net forward
         logits = self.post_net_forward(similarity, speech_lengths)
+        if return_inter_outputs:
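+            # Split the stacked similarity back into (ci_score, cd_score);
+            # assuming CI and CD scores are concatenated along the last dim.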
+            return logits, [(speech, speech_lengths), (profile, profile_lengths), torch.chunk(similarity, 2, dim=-1)]
         return logits
     def encode(
--
Gitblit v1.9.1