hnluo
2023-06-05 594b79f59e7eefa6955c729f6264c8c99d1d9571
funasr/bin/diar_infer.py
@@ -235,8 +235,11 @@
                new_seq.append(x)
            else:
                idx_list = np.where(seq < 2 ** vec_dim)[0]
                idx = np.abs(idx_list - i).argmin()
                new_seq.append(seq[idx_list[idx]])
                if len(idx_list) > 0:
                    idx = np.abs(idx_list - i).argmin()
                    new_seq.append(seq[idx_list[idx]])
                else:
                    new_seq.append(0)
        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):