From 4b30f336ee7e3ca405cfa6ff96d9b3c3e936f767 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 15:03:21 +0800
Subject: [PATCH] update repo

---
 funasr/build_utils/build_model_from_file.py |   39 +++++++++++++++++++++++++++++++++++++--
 1 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 5488c10..2eadae4 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -72,6 +72,8 @@
             model.load_state_dict(model_dict)
         else:
             model_dict = torch.load(model_file, map_location=device)
+    if task_name == "diar" and mode == "sond":
+        model_dict = fileter_model_dict(model_dict, model.state_dict())
     model.load_state_dict(model_dict)
     if model_name_pth is not None and not os.path.exists(model_name_pth):
         torch.save(model_dict, model_name_pth)
@@ -85,7 +87,7 @@
         ckpt,
         mode,
 ):
-    assert mode == "paraformer" or mode == "uniasr"
+    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
     logging.info("start convert tf model to torch model")
     from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
     var_dict_tf = load_tf_dict(ckpt)
@@ -113,7 +115,7 @@
         # stride_conv
         var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
-    else:
+    elif mode == "paraformer":
         # encoder
         var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
@@ -126,5 +128,38 @@
         # bias_encoder
         var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
+    else:
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # speaker encoder
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # cd scorer
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # ci scorer
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
     return var_dict_torch_update
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+    from collections import OrderedDict
+    new_dict = OrderedDict()
+    for key, value in src_dict.items():
+        if key in dest_dict:
+            new_dict[key] = value
+        else:
+            logging.info("{} is no longer needed in this model.".format(key))
+    for key, value in dest_dict.items():
+        if key not in new_dict:
+            logging.warning("{} is missed in checkpoint.".format(key))
+    return new_dict
\ No newline at end of file

--
Gitblit v1.9.1