志浩
2023-03-16 2868fe3df4e92a6ae3e327faf6e57ea492e04124
funasr/tasks/diar.py
@@ -23,7 +23,7 @@
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
@@ -122,6 +122,7 @@
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        resnet34=ResNet34Diar,
        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
@@ -160,6 +161,7 @@
    classes=dict(
        dot=DotScorer,
        cosine=CosScorer,
        conv=ConvEncoder,
    ),
    type_check=torch.nn.Module,
    default=None,
@@ -497,7 +499,7 @@
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
            device: Union[str, torch.device] = "cpu",
    ):
        """Build model from the files.
@@ -552,12 +554,27 @@
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
        model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
        model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
    @classmethod
    def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
        from collections import OrderedDict
        new_dict = OrderedDict()
        for key, value in src_dict.items():
            if key in dest_dict:
                new_dict[key] = value
            else:
                logging.info("{} is no longer needed in this model.".format(key))
        for key, value in dest_dict.items():
            if key not in new_dict:
                logging.warning("{} is missed in checkpoint.".format(key))
        return new_dict
    @classmethod
    def convert_tf2torch(
@@ -571,19 +588,24 @@
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # speech encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        if model.encoder is not None:
            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # speaker encoder
        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        if model.speaker_encoder is not None:
            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # cd scorer
        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        if model.cd_scorer is not None:
            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # ci scorer
        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        if model.ci_scorer is not None:
            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        if model.decoder is not None:
            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
            var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update