| | |
| | | import torch |
| | | from torch.nn import functional as F |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from typing import Tuple, Optional |
| | | from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from collections import OrderedDict |
| | | import logging |
| | | import numpy as np |
| | |
| | | return xs_pad, ilens |
| | | |
| | | |
| | | class ResNet34(torch.nn.Module): |
| | | class ResNet34(AbsEncoder): |
| | | def __init__( |
| | | self, |
| | | input_size, |
| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | """ |
| | | |
| | | super(ResNet34Diar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | tf2torch_tensor_name_prefix_torch="encoder", |
| | | tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder" |
| | | ): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | super(ResNet34SpL2RegDiar, self).__init__( |
| | | input_size, |
| | | use_head_conv=use_head_conv, |
| | |
| | | else: |
| | | logging.warning("{} is missed from tf checkpoint".format(name)) |
| | | |
| | | return var_dict_torch_update |
| | | return var_dict_torch_update |