speech_asr
2023-04-11 dfa356a10c698e4e0548ab2d05ae31ab142bd4aa
funasr/models/e2e_diar_sond.py
@@ -14,14 +14,8 @@
from torch.nn import functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.base_model import FunASRModel
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -43,9 +37,9 @@
    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        frontend: Optional[torch.nn.Module],
        specaug: Optional[torch.nn.Module],
        normalize: Optional[torch.nn.Module],
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
@@ -348,7 +342,7 @@
        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
        if isinstance(self.ci_scorer, AbsEncoder):
        if isinstance(self.ci_scorer, torch.nn.Module):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
        else: