志浩
2023-03-09 6e5c6f33fdce8acf631e1b18e9bc0fb82c3495ab
funasr/tasks/diar.py
@@ -23,7 +23,8 @@
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -121,10 +122,12 @@
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        resnet34=ResNet34Diar,
        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
    ),
    type_check=AbsEncoder,
    type_check=torch.nn.Module,
    default="resnet34",
)
speaker_encoder_choices = ClassChoices(
@@ -158,6 +161,7 @@
    classes=dict(
        dot=DotScorer,
        cosine=CosScorer,
        conv=ConvEncoder,
    ),
    type_check=torch.nn.Module,
    default=None,
@@ -187,6 +191,8 @@
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --label_aggregator and --label_aggregator_conf
        label_aggregator_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf