| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | |
| | | super(ResNet34Diar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | super(ResNet34SpL2RegDiar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape |
| | | )) |
| | | else: |
| | | var_dict_torch_update[name] = torch.Tensor(map_dict[name]).type(torch.int64).to("cpu") |
| | | var_dict_torch_update[name] = torch.from_numpy(np.array(map_dict[name])).type(torch.int64).to("cpu") |
| | | logging.info("torch tensor: {}, manually assigning to: {}".format( |
| | | name, map_dict[name] |
| | | )) |
| | | else: |
| | | logging.warning("{} is missed from tf checkpoint".format(name)) |
| | | |
| | | return var_dict_torch_update |
| | | return var_dict_torch_update |