From 7907c3df0743da5640eb3be5fa18d37fe6017bbe Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期四, 09 三月 2023 17:07:51 +0800
Subject: [PATCH] modify unit test for speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch

---
 funasr/tasks/diar.py |   25 +++++++++++++++----------
 1 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e3d7a..e699dcc 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -573,19 +573,24 @@
         var_dict_torch = model.state_dict()
         var_dict_torch_update = dict()
         # speech encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # speaker encoder
-        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # cd scorer
-        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # ci scorer
-        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update

--
Gitblit v1.9.1