| | |
| | | var_dict_torch = model.state_dict() |
| | | var_dict_torch_update = dict() |
| | | # speech encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.encoder is not None: |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # speaker encoder |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.speaker_encoder is not None: |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # cd scorer |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.cd_scorer is not None: |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # ci scorer |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.ci_scorer is not None: |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |