| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | if task_name == "diar" and mode == "sond": |
| | | model_dict = fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | |
| | | ckpt, |
| | | mode, |
| | | ): |
| | | assert mode == "paraformer" or mode == "uniasr" |
| | | assert mode == "paraformer" or mode == "uniasr" or mode == "sond" |
| | | logging.info("start convert tf model to torch model") |
| | | from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict |
| | | var_dict_tf = load_tf_dict(ckpt) |
| | |
| | | # stride_conv |
| | | var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | else: |
| | | elif mode == "paraformer": |
| | | # encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | |
| | | # bias_encoder |
| | | var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | else: |
| | | if model.encoder is not None: |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # speaker encoder |
| | | if model.speaker_encoder is not None: |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # cd scorer |
| | | if model.cd_scorer is not None: |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # ci scorer |
| | | if model.ci_scorer is not None: |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |
| | | |
def fileter_model_dict(src_dict: dict, dest_dict: dict) -> dict:
    """Keep only the entries of ``src_dict`` whose keys also exist in ``dest_dict``.

    Entries are copied in the iteration order of ``src_dict``; neither input
    mapping is modified.

    Two kinds of mismatch are reported through the root ``logging`` functions:
    keys present only in ``src_dict`` are logged at INFO level, and keys present
    only in ``dest_dict`` are logged at WARNING level.

    Args:
        src_dict: Mapping to filter (the loaded checkpoint at the visible call
            site in this file).
        dest_dict: Mapping whose keys define what is kept (``model.state_dict()``
            at the visible call site in this file).

    Returns:
        A new ``collections.OrderedDict`` holding the ``src_dict`` items whose
        keys are in ``dest_dict``.
    """
    # Imported locally so the function does not depend on a module-level import.
    from collections import OrderedDict

    new_dict = OrderedDict()
    for key, value in src_dict.items():
        if key in dest_dict:
            new_dict[key] = value
        else:
            # Lazy %-style arguments: the message is only built if it is emitted.
            logging.info("%s is no longer needed in this model.", key)

    # Only the keys are needed here, so iterate the mapping directly.
    for key in dest_dict:
        if key not in new_dict:
            logging.warning("%s is missed in checkpoint.", key)
    return new_dict