| | |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | |
| | | fsmn=FsmnEncoder, |
| | | conv=ConvEncoder, |
| | | resnet34=ResNet34Diar, |
| | | resnet34_sp_l2reg=ResNet34SpL2RegDiar, |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ecapa_tdnn=ECAPA_TDNN, |
| | |
| | | classes=dict( |
| | | dot=DotScorer, |
| | | cosine=CosScorer, |
| | | conv=ConvEncoder, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default=None, |
| | |
| | | var_dict_torch = model.state_dict() |
| | | var_dict_torch_update = dict() |
| | | # speech encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.encoder is not None: |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # speaker encoder |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.speaker_encoder is not None: |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # cd scorer |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.cd_scorer is not None: |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # ci scorer |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.ci_scorer is not None: |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |