funasr/export/models/predictor/cif.py
@@ -109,7 +109,8 @@ frames = torch.stack(list_frames, 1) list_ls = [] len_labels = torch.round(alphas.sum(-1)).int() max_label_len = len_labels.max() max_label_len = len_labels.max().item() print("type: {}".format(type(max_label_len))) for b in range(batch_size): fire = fires[b, :] l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())