嘉渊
2023-06-15 4b30f336ee7e3ca405cfa6ff96d9b3c3e936f767
funasr/build_utils/build_model_from_file.py
@@ -72,6 +72,8 @@
            model.load_state_dict(model_dict)
        else:
            model_dict = torch.load(model_file, map_location=device)
    if task_name == "diar" and mode == "sond":
        model_dict = fileter_model_dict(model_dict, model.state_dict())
    model.load_state_dict(model_dict)
    if model_name_pth is not None and not os.path.exists(model_name_pth):
        torch.save(model_dict, model_name_pth)
@@ -85,7 +87,7 @@
        ckpt,
        mode,
):
    assert mode == "paraformer" or mode == "uniasr"
    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
@@ -113,7 +115,7 @@
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
    elif mode == "paraformer":
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
@@ -126,5 +128,38 @@
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
        if model.encoder is not None:
            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # speaker encoder
        if model.speaker_encoder is not None:
            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # cd scorer
        if model.cd_scorer is not None:
            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # ci scorer
        if model.ci_scorer is not None:
            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        if model.decoder is not None:
            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    from collections import OrderedDict
    new_dict = OrderedDict()
    for key, value in src_dict.items():
        if key in dest_dict:
            new_dict[key] = value
        else:
            logging.info("{} is no longer needed in this model.".format(key))
    for key, value in dest_dict.items():
        if key not in new_dict:
            logging.warning("{} is missed in checkpoint.".format(key))
    return new_dict