Zhihao Du
2023-05-25 6cc2e585b745685c71c10af80dda3553cc949460
funasr/bin/diar_infer.py
@@ -235,8 +235,11 @@
                new_seq.append(x)
            else:
                idx_list = np.where(seq < 2 ** vec_dim)[0]
                idx = np.abs(idx_list - i).argmin()
                new_seq.append(seq[idx_list[idx]])
                if len(idx_list) > 0:
                    idx = np.abs(idx_list - i).argmin()
                    new_seq.append(seq[idx_list[idx]])
                else:
                    new_seq.append(0)
        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):