志浩
2023-03-09 7907c3df0743da5640eb3be5fa18d37fe6017bbe
modify unit test for speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch
1个文件已修改
5 ■■■■■ 已修改文件
funasr/tasks/diar.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/diar.py
@@ -573,18 +573,23 @@
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # speech encoder
        if model.encoder is not None:
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # speaker encoder
        if model.speaker_encoder is not None:
        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # cd scorer
        if model.cd_scorer is not None:
        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # ci scorer
        if model.ci_scorer is not None:
        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        if model.decoder is not None:
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)