游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/build_utils/build_model_from_file.py
@@ -6,7 +6,6 @@
import torch
import yaml
from typeguard import check_argument_types
from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
@@ -30,7 +29,6 @@
        device: Device type, "cpu", "cuda", or "cuda:N".
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
@@ -72,9 +70,14 @@
            model.load_state_dict(model_dict)
        else:
            model_dict = torch.load(model_file, map_location=device)
    if task_name == "ss":
        model_dict = model_dict['model']
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    model.load_state_dict(model_dict)
    if task_name == "vad":
        model.encoder.load_state_dict(model_dict)
    else:
        model.load_state_dict(model_dict)
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))