From 083cf167dbfd07b1f4adaacee2952329e9409620 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 14 三月 2023 14:45:01 +0800
Subject: [PATCH] Merge pull request #227 from songtaoshi/dev_sst

---
 funasr/tasks/diar.py |   35 +++++++++++++++++++++++------------
 1 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 9a43945..e699dcc 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -23,7 +23,8 @@
 from funasr.layers.utterance_mvn import UtteranceMVN
 from funasr.layers.label_aggregation import LabelAggregate
 from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
 from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
 from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
 from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
@@ -121,10 +122,12 @@
         fsmn=FsmnEncoder,
         conv=ConvEncoder,
         resnet34=ResNet34Diar,
+        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
+        ecapa_tdnn=ECAPA_TDNN,
     ),
-    type_check=AbsEncoder,
+    type_check=torch.nn.Module,
     default="resnet34",
 )
 speaker_encoder_choices = ClassChoices(
@@ -158,6 +161,7 @@
     classes=dict(
         dot=DotScorer,
         cosine=CosScorer,
+        conv=ConvEncoder,
     ),
     type_check=torch.nn.Module,
     default=None,
@@ -187,6 +191,8 @@
         specaug_choices,
         # --normalize and --normalize_conf
         normalize_choices,
+        # --label_aggregator and --label_aggregator_conf
+        label_aggregator_choices,
         # --model and --model_conf
         model_choices,
         # --encoder and --encoder_conf
@@ -567,19 +573,24 @@
         var_dict_torch = model.state_dict()
         var_dict_torch_update = dict()
         # speech encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # speaker encoder
-        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # cd scorer
-        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # ci scorer
-        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update

--
Gitblit v1.9.1