| | |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | |
| | | fsmn=FsmnEncoder, |
| | | conv=ConvEncoder, |
| | | resnet34=ResNet34Diar, |
| | | resnet34_sp_l2reg=ResNet34SpL2RegDiar, |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ecapa_tdnn=ECAPA_TDNN, |
| | |
| | | classes=dict( |
| | | dot=DotScorer, |
| | | cosine=CosScorer, |
| | | conv=ConvEncoder, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default=None, |
| | |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | device: Union[str, torch.device] = "cpu", |
| | | ): |
| | | """Build model from the files. |
| | | |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | | logging.info("model_file is saved to pth: {}".format(model_name_pth)) |
| | | |
| | | return model, args |
| | | |
@classmethod
def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
    """Keep only the checkpoint entries whose keys the target mapping also has.

    Args:
        src_dict: Mapping loaded from a checkpoint (name -> value).
        dest_dict: Mapping to load into; only its keys are consulted.

    Returns:
        An ``OrderedDict`` holding the items of ``src_dict`` whose keys occur
        in ``dest_dict``, in ``src_dict`` iteration order.

    Every key dropped from ``src_dict`` is reported at INFO level, and every
    key of ``dest_dict`` that ends up absent from the result is reported at
    WARNING level.  Neither case raises.
    """
    from collections import OrderedDict

    kept = OrderedDict()
    for name, tensor in src_dict.items():
        if name not in dest_dict:
            # Present in the checkpoint but unknown to the target: drop it.
            logging.info("{} is no longer needed in this model.".format(name))
            continue
        kept[name] = tensor

    for name in dest_dict:
        if name in kept:
            continue
        # Expected by the target but not supplied by the checkpoint.
        logging.warning("{} is missed in checkpoint.".format(name))

    return kept
| | | |
| | | @classmethod |
| | | def convert_tf2torch( |
| | |
| | | var_dict_torch = model.state_dict() |
| | | var_dict_torch_update = dict() |
| | | # speech encoder |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.encoder is not None: |
| | | var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # speaker encoder |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.speaker_encoder is not None: |
| | | var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # cd scorer |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.cd_scorer is not None: |
| | | var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # ci scorer |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.ci_scorer is not None: |
| | | var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | # decoder |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | if model.decoder is not None: |
| | | var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |