From e422c6197b5bcada0429986500d8d5ca4ffcb3e4 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 10 五月 2023 19:23:37 +0800
Subject: [PATCH] update repo
---
funasr/tasks/diar.py | 35 +++++++++++++++++++++++++----------
1 files changed, 25 insertions(+), 10 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 6204cb7..6be75a0 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -50,7 +50,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -106,7 +106,7 @@
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="sond",
)
encoder_choices = ClassChoices(
@@ -507,7 +507,7 @@
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
+ device: Union[str, torch.device] = "cpu",
):
"""Build model from the files.
@@ -536,9 +536,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -553,7 +553,7 @@
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
- model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+ model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
@@ -562,12 +562,27 @@
model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
+ model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
+
+ @classmethod
+ def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
+ from collections import OrderedDict
+ new_dict = OrderedDict()
+ for key, value in src_dict.items():
+ if key in dest_dict:
+ new_dict[key] = value
+ else:
+ logging.info("{} is no longer needed in this model.".format(key))
+ for key, value in dest_dict.items():
+ if key not in new_dict:
+ logging.warning("{} is missed in checkpoint.".format(key))
+ return new_dict
@classmethod
def convert_tf2torch(
@@ -787,10 +802,10 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
- retval = ("speech", "profile", "binary_labels")
+ retval = ("speech", )
else:
# Recognition mode
- retval = ("speech")
+ retval = ("speech", )
return retval
@classmethod
@@ -879,9 +894,9 @@
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
--
Gitblit v1.9.1