From f3fb62cd4fb1f7eebfd862971cc75dd41478d8c2 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 16 六月 2023 13:15:48 +0800
Subject: [PATCH] update repo
---
funasr/build_utils/build_model_from_file.py | 18 +++++++++++++++---
1 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 53eafc1..85cf8b9 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -87,7 +87,7 @@
ckpt,
mode,
):
- assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv"
+ assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
logging.info("start convert tf model to torch model")
from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
@@ -148,7 +148,7 @@
if model.decoder is not None:
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
- else:
+ elif "mode" == "sv":
# speech encoder
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
@@ -158,7 +158,19 @@
# decoder
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
-
+ else:
+ # encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # predictor
+ var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # bias_encoder
+ var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
return var_dict_torch_update
--
Gitblit v1.9.1