kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/specaug/profileaug.py
@@ -2,25 +2,26 @@
import numpy as np
import torch
from torch.nn import functional as F
from funasr.models.specaug.abs_profileaug import AbsProfileAug
import torch.nn as nn
class ProfileAug(AbsProfileAug):
class ProfileAug(nn.Module):
    """
    Implement the augmentation for profiles including:
    - Split aug: split one profile into two profiles, i.e., main and inaccurate, labels assigned to main
    - Merge aug: merge two profiles into one, labels are also merged into one, the other set to zero
    - Disturb aug: disturb some profile with others to simulate the inaccurate clustering centroids.
    """
    def __init__(
            self,
            apply_split_aug: bool = True,
            split_aug_prob: float = 0.05,
            apply_merge_aug: bool = True,
            merge_aug_prob: float = 0.2,
            apply_disturb_aug: bool = True,
            disturb_aug_prob: float = 0.4,
            disturb_alpha: float = 0.2,
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        super().__init__()
        self.apply_split_aug = apply_split_aug
@@ -47,8 +48,9 @@
            to_cover_idx = pad_spk_idx[torch.randint(len(pad_spk_idx), ())]
            disturb_vec = torch.randn((dim,)).to(profile)
            disturb_vec = F.normalize(disturb_vec, dim=-1)
            profile[idx, to_cover_idx] = F.normalize(profile[idx, split_spk_idx] +
                                                     self.disturb_alpha * disturb_vec)
            profile[idx, to_cover_idx] = F.normalize(
                profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec
            )
            mask[idx, split_spk_idx] = 0
            mask[idx, to_cover_idx] = 0
        return profile, binary_labels, mask
@@ -63,15 +65,19 @@
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(valid_spk_idx) == 0:
                continue
            to_merge = torch.randint(len(valid_spk_idx), (2, ))
            to_merge = torch.randint(len(valid_spk_idx), (2,))
            spk_idx_1, spk_idx_2 = valid_spk_idx[to_merge[0]], valid_spk_idx[to_merge[1]]
            # merge profile
            profile[idx, spk_idx_1] = profile[idx, spk_idx_1] + profile[idx, spk_idx_2]
            profile[idx, spk_idx_1] = F.normalize(profile[idx, spk_idx_1], dim=-1)
            profile[idx, spk_idx_2] = 0
            # merge binary labels
            binary_labels[idx, :, spk_idx_1] = binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            binary_labels[idx, :, spk_idx_1] = (binary_labels[idx, :, spk_idx_1] > 0).to(binary_labels)
            binary_labels[idx, :, spk_idx_1] = (
                binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            )
            binary_labels[idx, :, spk_idx_1] = (binary_labels[idx, :, spk_idx_1] > 0).to(
                binary_labels
            )
            binary_labels[idx, :, spk_idx_2] = 0
            mask[idx, spk_idx_1] = 0
@@ -93,30 +99,44 @@
            to_disturb_idx = pos_spk_idx[torch.randint(len(pos_spk_idx), ())]
            disturb_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
            alpha = self.disturb_alpha * torch.rand(()).item()
            profile[idx, to_disturb_idx] = ((1 - alpha) * profile[idx, to_disturb_idx]
                                            + alpha * profile[idx, disturb_idx])
            profile[idx, to_disturb_idx] = (1 - alpha) * profile[
                idx, to_disturb_idx
            ] + alpha * profile[idx, disturb_idx]
            profile[idx, to_disturb_idx] = F.normalize(profile[idx, to_disturb_idx], dim=-1)
            mask[idx, to_disturb_idx] = 0
        return profile, binary_labels, mask
    def forward(
            self,
            speech: torch.Tensor, speech_lengths: torch.Tensor = None,
            profile: torch.Tensor = None, profile_lengths: torch.Tensor = None,
            binary_labels: torch.Tensor = None, labels_length: torch.Tensor = None
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        labels_length: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # copy inputs to avoid inplace-operation
        speech, profile, binary_labels = torch.clone(speech), torch.clone(profile), torch.clone(binary_labels)
        speech, profile, binary_labels = (
            torch.clone(speech),
            torch.clone(profile),
            torch.clone(binary_labels),
        )
        profile = F.normalize(profile, dim=-1)
        profile_mask = torch.ones(profile.shape[:2]).to(profile)
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(profile, binary_labels, profile_mask)
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(profile, binary_labels, profile_mask)
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(profile, binary_labels, profile_mask)
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )
        return speech, profile, binary_labels