From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/models/e2e_diar_sond.py | 114 +++++++++++++++++++++++++++++++++++++++++++++++----------
1 files changed, 94 insertions(+), 20 deletions(-)
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index dc7135f..aa6294a 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
-
+import logging
+import random
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
@@ -12,13 +13,20 @@
import numpy as np
import torch
from torch.nn import functional as F
-from typeguard import check_argument_types
+from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
-from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
+from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
+from funasr.utils.hinter import hint_once
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -37,9 +45,10 @@
def __init__(
self,
vocab_size: int,
- frontend: Optional[torch.nn.Module],
- specaug: Optional[torch.nn.Module],
- normalize: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ profileaug: Optional[AbsProfileAug],
+ normalize: Optional[AbsNormalize],
encoder: torch.nn.Module,
speaker_encoder: Optional[torch.nn.Module],
ci_scorer: torch.nn.Module,
@@ -55,8 +64,10 @@
speaker_discrimination_loss_weight: float = 1.0,
inter_score_loss_weight: float = 0.0,
inputs_type: str = "raw",
+ model_regularizer_weight: float = 0.0,
+ freeze_encoder: bool = False,
+ onfly_shuffle_speaker: bool = True,
):
- assert check_argument_types()
super().__init__()
@@ -67,12 +78,16 @@
self.normalize = normalize
self.frontend = frontend
self.specaug = specaug
+ self.profileaug = profileaug
self.label_aggregator = label_aggregator
self.decoder = decoder
self.token_list = token_list
self.max_spk_num = max_spk_num
self.normalize_speech_speaker = normalize_speech_speaker
self.ignore_id = ignore_id
+ self.model_regularizer_weight = model_regularizer_weight
+ self.freeze_encoder = freeze_encoder
+ self.onfly_shuffle_speaker = onfly_shuffle_speaker
self.criterion_diar = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
@@ -87,13 +102,44 @@
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0
self.inputs_type = inputs_type
+ self.to_regularize_parameters = None
+
+ def get_regularize_parameters(self):
+ to_regularize_parameters, normal_parameters = [], []
+ for name, param in self.named_parameters():
+ if ("encoder" in name and "weight" in name and "bn" not in name and
+ ("conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name)
+ ):
+ to_regularize_parameters.append((name, param))
+ else:
+ normal_parameters.append((name, param))
+ self.to_regularize_parameters = to_regularize_parameters
+ return to_regularize_parameters, normal_parameters
def generate_pse_embedding(self):
- embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
+ embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32)
for idx, pse_label in enumerate(self.token_list):
- emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float)
+ emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
embedding[idx] = emb
return torch.from_numpy(embedding)
+
+ def rand_permute_speaker(self, raw_profile, raw_binary_labels):
+ """
+ raw_profile: B, N, D
+ raw_binary_labels: B, T, N
+ """
+ assert raw_profile.shape[1] == raw_binary_labels.shape[2], \
+ "Num profile: {}, Num label: {}".format(raw_profile.shape[1], raw_binary_labels.shape[-1])
+ profile = torch.clone(raw_profile)
+ binary_labels = torch.clone(raw_binary_labels)
+ bsz, num_spk = profile.shape[0], profile.shape[1]
+ for i in range(bsz):
+ idx = list(range(num_spk))
+ random.shuffle(idx)
+ profile[i] = profile[i][idx, :]
+ binary_labels[i] = binary_labels[i][:, idx]
+
+ return profile, binary_labels
def forward(
self,
@@ -120,28 +166,44 @@
"""
assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
+ if self.freeze_encoder:
+ hint_once("Freeze encoder", "freeze_encoder", rank=0)
+ self.encoder.eval()
self.forward_steps = self.forward_steps + 1
if self.pse_embedding.device != speech.device:
self.pse_embedding = self.pse_embedding.to(speech.device)
self.power_weight = self.power_weight.to(speech.device)
self.int_token_arr = self.int_token_arr.to(speech.device)
- # 1. Network forward
+ if self.onfly_shuffle_speaker:
+ hint_once("On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0)
+ profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)
+
+ # 0a. Aggregate time-domain labels to match forward outputs
+ if self.label_aggregator is not None:
+ binary_labels, binary_labels_lengths = self.label_aggregator(
+ binary_labels, binary_labels_lengths
+ )
+ # 0b. augment profiles
+ if self.profileaug is not None and self.training:
+ speech, profile, binary_labels = self.profileaug(
+ speech, speech_lengths,
+ profile, profile_lengths,
+ binary_labels, binary_labels_lengths
+ )
+
+ # 1. Calculate power-set encoding (PSE) labels
+ pad_bin_labels = F.pad(binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0)
+ raw_pse_labels = torch.sum(pad_bin_labels * self.power_weight, dim=2, keepdim=True)
+ pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
+
+ # 2. Network forward
pred, inter_outputs = self.prediction_forward(
speech, speech_lengths,
profile, profile_lengths,
return_inter_outputs=True
)
(speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = inter_outputs
-
- # 2. Aggregate time-domain labels to match forward outputs
- if self.label_aggregator is not None:
- binary_labels, binary_labels_lengths = self.label_aggregator(
- binary_labels, binary_labels_lengths
- )
- # 2. Calculate power-set encoding (PSE) labels
- raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
- pse_labels = torch.argmax((raw_pse_labels.int() == self.int_token_arr).float(), dim=2)
# If encoder uses conv* as input_layer (i.e., subsampling),
# the sequence length of 'pred' might be slightly less than the
@@ -158,9 +220,14 @@
loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
+ regularizer_loss = None
+ if self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None:
+ regularizer_loss = self.calculate_regularizer_loss()
label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]).to(pse_labels.device)
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
+ # if regularizer_loss is not None:
+ # loss = loss + regularizer_loss * self.model_regularizer_weight
(
correct,
@@ -197,6 +264,7 @@
loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None,
loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None,
+ regularizer_loss=regularizer_loss.detach() if regularizer_loss is not None else None,
sad_mr=sad_mr,
sad_fr=sad_fr,
mi=mi,
@@ -209,6 +277,12 @@
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
+
+ def calculate_regularizer_loss(self):
+ regularizer_loss = 0.0
+ for name, param in self.to_regularize_parameters:
+ regularizer_loss = regularizer_loss + torch.norm(param, p=2)
+ return regularizer_loss
def classification_loss(
self,
@@ -342,7 +416,7 @@
cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
- if isinstance(self.ci_scorer, torch.nn.Module):
+ if isinstance(self.ci_scorer, AbsEncoder):
ci_simi = self.ci_scorer(ge_in, ge_len)[0]
ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
else:
--
Gitblit v1.9.1