From 4b30f336ee7e3ca405cfa6ff96d9b3c3e936f767 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 15:03:21 +0800
Subject: [PATCH] update repo
---
funasr/build_utils/build_model_from_file.py | 39 +++++++++++++++++++++++++++++++++++++--
1 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 5488c10..2eadae4 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -72,6 +72,8 @@
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
+ if task_name == "diar" and mode == "sond":
+ model_dict = fileter_model_dict(model_dict, model.state_dict())
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
@@ -85,7 +87,7 @@
ckpt,
mode,
):
- assert mode == "paraformer" or mode == "uniasr"
+ assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
logging.info("start convert tf model to torch model")
from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
@@ -113,7 +115,7 @@
# stride_conv
var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
- else:
+ elif mode == "paraformer":
# encoder
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
@@ -126,5 +128,38 @@
# bias_encoder
var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
+ else:
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # speaker encoder
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # cd scorer
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # ci scorer
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+ from collections import OrderedDict
+ new_dict = OrderedDict()
+ for key, value in src_dict.items():
+ if key in dest_dict:
+ new_dict[key] = value
+ else:
+ logging.info("{} is no longer needed in this model.".format(key))
+ for key, value in dest_dict.items():
+ if key not in new_dict:
+ logging.warning("{} is missed in checkpoint.".format(key))
+ return new_dict
\ No newline at end of file
--
Gitblit v1.9.1