From 8cc5bbf99a59694228aafcbe8712e09b9a4cb26b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 27 Feb 2023 17:01:48 +0800
Subject: [PATCH] Merge pull request #159 from alibaba-damo-academy/dev_dzh
---
funasr/tasks/diar.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index f3212f1..73c51e3 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -24,6 +24,7 @@
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -123,8 +124,9 @@
resnet34=ResNet34Diar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
+ ecapa_tdnn=ECAPA_TDNN,
),
- type_check=AbsEncoder,
+ type_check=torch.nn.Module,
default="resnet34",
)
speaker_encoder_choices = ClassChoices(
@@ -187,6 +189,8 @@
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
+ # --label_aggregator and --label_aggregator_conf
+ label_aggregator_choices,
# --model and --model_conf
model_choices,
# --encoder and --encoder_conf
@@ -368,7 +372,7 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
- retval = ("speech", "profile", "label")
+ retval = ("speech", "profile", "binary_labels")
else:
# Recognition mode
retval = ("speech", "profile")
--
Gitblit v1.9.1