From 85566dbb2a4d8d09f80a2460ae8c3e0a9d908cf4 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期一, 13 三月 2023 20:00:19 +0800
Subject: [PATCH] fix the output of vad_results is null
---
funasr/tasks/diar.py | 25 +++++++++++++++----------
1 files changed, 15 insertions(+), 10 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e3d7a..e699dcc 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -573,19 +573,24 @@
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
- var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
- var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
- var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
--
Gitblit v1.9.1