modify unit test for speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch
| | |
| | | tt = xs_pad.shape[2] |
| | | num_chunk = int(math.ceil(tt / pooling_stride)) |
| | | pad = pooling_size // 2 |
| | | features = F.pad(xs_pad, (0, 0, pad, pad), "reflect") |
| | | if len(xs_pad.shape == 4): |
| | | features = F.pad(xs_pad, (0, 0, pad, pad), "reflect") |
| | | else: |
| | | features = F.pad(xs_pad, (pad, pad), "reflect") |
| | | stat_list = [] |
| | | |
| | | for i in range(num_chunk): |
| | | # B x C |
| | | st, ed = i*pooling_stride, i*pooling_stride+pooling_size |
| | | stat = statistic_pooling(features[:, :, st: ed, :], pooling_dim=pooling_dim) |
| | | stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim) |
| | | stat_list.append(stat.unsqueeze(2)) |
| | | |
| | | # B x C x T |