zhifu gao
2023-02-27 8cc5bbf99a59694228aafcbe8712e09b9a4cb26b
funasr/tasks/diar.py
@@ -24,6 +24,7 @@
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -123,8 +124,9 @@
        resnet34=ResNet34Diar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
    ),
    type_check=AbsEncoder,
    type_check=torch.nn.Module,
    default="resnet34",
)
speaker_encoder_choices = ClassChoices(
@@ -187,6 +189,8 @@
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --label_aggregator and --label_aggregator_conf
        label_aggregator_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
@@ -368,7 +372,7 @@
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "profile", "label")
            retval = ("speech", "profile", "binary_labels")
        else:
            # Recognition mode
            retval = ("speech", "profile")