hnluo
2023-09-11 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b
funasr/build_utils/build_model_from_file.py
@@ -6,7 +6,6 @@
import torch
import yaml
from typeguard import check_argument_types
from funasr.build_utils.build_model import build_model
from funasr.models.base_model import FunASRModel
@@ -30,7 +29,6 @@
        device: Device type, "cpu", "cuda", or "cuda:N".
    """
    assert check_argument_types()
    if config_file is None:
        assert model_file is not None, (
            "The argument 'model_file' must be provided "
@@ -72,9 +70,14 @@
            model.load_state_dict(model_dict)
        else:
            model_dict = torch.load(model_file, map_location=device)
    if task_name == "ss":
        model_dict = model_dict['model']
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    model.load_state_dict(model_dict)
    if task_name == "vad":
        model.encoder.load_state_dict(model_dict)
    else:
        model.load_state_dict(model_dict)
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
        logging.info("model_file is saved to pth: {}".format(model_name_pth))
@@ -87,7 +90,7 @@
        ckpt,
        mode,
):
    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
    assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
@@ -128,7 +131,7 @@
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
    elif "mode" == "sond":
        if model.encoder is not None:
            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
@@ -148,8 +151,33 @@
        if model.decoder is not None:
            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
    elif "mode" == "sv":
        # speech encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # pooling layer
        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    from collections import OrderedDict
@@ -162,4 +190,4 @@
    for key, value in dest_dict.items():
        if key not in new_dict:
            logging.warning("{} is missed in checkpoint.".format(key))
    return new_dict
    return new_dict