From 9be8a443d74d68f179de88fff13b4e8424579d7b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 三月 2023 18:24:39 +0800
Subject: [PATCH] Merge pull request #207 from alibaba-damo-academy/dev_dzh
---
funasr/tasks/diar.py | 29 ++++++++++++++++++-----------
1 files changed, 18 insertions(+), 11 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 73c51e3..e699dcc 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -23,7 +23,7 @@
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
@@ -122,6 +122,7 @@
fsmn=FsmnEncoder,
conv=ConvEncoder,
resnet34=ResNet34Diar,
+ resnet34_sp_l2reg=ResNet34SpL2RegDiar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
@@ -160,6 +161,7 @@
classes=dict(
dot=DotScorer,
cosine=CosScorer,
+ conv=ConvEncoder,
),
type_check=torch.nn.Module,
default=None,
@@ -571,19 +573,24 @@
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
- var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
- var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
- var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
--
Gitblit v1.9.1