雾聪
2023-06-02 fae856e23d45fd27d5fd55fd036e8e3fc7b24915
funasr/bin/diar_infer.py
@@ -235,8 +235,11 @@
                new_seq.append(x)
            else:
                idx_list = np.where(seq < 2 ** vec_dim)[0]
                idx = np.abs(idx_list - i).argmin()
                new_seq.append(seq[idx_list[idx]])
                if len(idx_list) > 0:
                    idx = np.abs(idx_list - i).argmin()
                    new_seq.append(seq[idx_list[idx]])
                else:
                    new_seq.append(0)
        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):