嘉渊
2023-06-15 27fddb4982855d80b850d66f019d20ec19d8d196
funasr/build_utils/build_model_from_file.py
@@ -87,7 +87,7 @@
        ckpt,
        mode,
):
    assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
    assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv"
    logging.info("start convert tf model to torch model")
    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
    var_dict_tf = load_tf_dict(ckpt)
@@ -128,7 +128,7 @@
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
    else:
    elif "mode" == "sond":
        if model.encoder is not None:
            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
@@ -148,8 +148,21 @@
        if model.decoder is not None:
            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
    else:
        # speech encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # pooling layer
        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
    return var_dict_torch_update
def fileter_model_dict(src_dict: dict, dest_dict: dict):
    from collections import OrderedDict
@@ -162,4 +175,4 @@
    for key, value in dest_dict.items():
        if key not in new_dict:
            logging.warning("{} is missed in checkpoint.".format(key))
    return new_dict
    return new_dict