游雁
2023-07-05 4e2fe544ae37174a3e09dfcdbbdae5abfe711e53
funasr/models/e2e_diar_sond.py
@@ -12,7 +12,6 @@
import numpy as np
import torch
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -22,7 +21,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -35,9 +34,13 @@
        yield
class DiarSondModel(AbsESPnetModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
class DiarSondModel(FunASRModel):
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """
    def __init__(
@@ -62,7 +65,6 @@
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
    ):
        assert check_argument_types()
        super().__init__()
@@ -111,7 +113,6 @@
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
        Args:
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,) default None for chunk interator,
@@ -387,7 +388,6 @@
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
@@ -487,4 +487,4 @@
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
        )