From 85566dbb2a4d8d09f80a2460ae8c3e0a9d908cf4 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期一, 13 三月 2023 20:00:19 +0800
Subject: [PATCH] fix the output of vad_results is null

---
 funasr/tasks/diar.py |   25 +++++++++++++++----------
 1 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e3d7a..e699dcc 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -573,19 +573,24 @@
         var_dict_torch = model.state_dict()
         var_dict_torch_update = dict()
         # speech encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # speaker encoder
-        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # cd scorer
-        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # ci scorer
-        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update

--
Gitblit v1.9.1