| | |
| | | ckpt, |
| | | mode, |
| | | ): |
| | | assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" |
| | | assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp" |
| | | logging.info("start convert tf model to torch model") |
| | | from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict |
| | | var_dict_tf = load_tf_dict(ckpt) |
| | |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | else: |
| | | elif "mode" == "sv": |
| | | # speech encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | else: |
| | | # encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # predictor |
| | | var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # bias_encoder |
| | | var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | return var_dict_torch_update |
| | | |
| | | return var_dict_torch_update |