From ac00b7deee093773ee2f42f2694746dfbbd8163f Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 16:53:26 +0800
Subject: [PATCH] update repo

---
 funasr/build_utils/build_model_from_file.py |   54 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 51de5b0..53eafc1 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -17,6 +17,7 @@
         model_file: Union[Path, str] = None,
         cmvn_file: Union[Path, str] = None,
         device: str = "cpu",
+        task_name: str = "asr",
         mode: str = "paraformer",
 ):
     """Build model from the files.
@@ -44,6 +45,7 @@
     if cmvn_file is not None:
         args["cmvn_file"] = cmvn_file
     args = argparse.Namespace(**args)
+    args.task_name = task_name
     model = build_model(args)
     if not isinstance(model, FunASRModel):
         raise RuntimeError(
@@ -70,6 +72,8 @@
             model.load_state_dict(model_dict)
         else:
             model_dict = torch.load(model_file, map_location=device)
+    if task_name == "diar" and mode == "sond":
+        model_dict = fileter_model_dict(model_dict, model.state_dict())
     model.load_state_dict(model_dict)
     if model_name_pth is not None and not os.path.exists(model_name_pth):
         torch.save(model_dict, model_name_pth)
@@ -83,7 +87,7 @@
         ckpt,
         mode,
 ):
-    assert mode == "paraformer" or mode == "uniasr"
+    assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv"
     logging.info("start convert tf model to torch model")
     from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
     var_dict_tf = load_tf_dict(ckpt)
@@ -111,7 +115,7 @@
         # stride_conv
         var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
-    else:
+    elif mode == "paraformer":
         # encoder
         var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
@@ -124,5 +128,51 @@
         # bias_encoder
         var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
+    elif "mode" == "sond":
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # speaker encoder
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # cd scorer
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # ci scorer
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
+    else:
+        # speech encoder
+        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # pooling layer
+        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+        # decoder
+        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
+
+        return var_dict_torch_update
 
     return var_dict_torch_update
+
+
+def fileter_model_dict(src_dict: dict, dest_dict: dict):
+    from collections import OrderedDict
+    new_dict = OrderedDict()
+    for key, value in src_dict.items():
+        if key in dest_dict:
+            new_dict[key] = value
+        else:
+            logging.info("{} is no longer needed in this model.".format(key))
+    for key, value in dest_dict.items():
+        if key not in new_dict:
+            logging.warning("{} is missed in checkpoint.".format(key))
+    return new_dict

--
Gitblit v1.9.1