hnluo
2023-06-29 c2dee5e3c29eba79e591d9e9caebaef15ea4e56b
funasr/models/e2e_diar_sond.py
@@ -14,9 +14,15 @@
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -30,16 +36,20 @@
class DiarSondModel(FunASRModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """
    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[torch.nn.Module],
        specaug: Optional[torch.nn.Module],
        normalize: Optional[torch.nn.Module],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
@@ -105,7 +115,6 @@
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
        Args:
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,) default None for chunk iterator,
@@ -342,7 +351,7 @@
        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
        if isinstance(self.ci_scorer, torch.nn.Module):
        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
        else:
@@ -381,7 +390,6 @@
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
@@ -481,4 +489,4 @@
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
        )