Han Zhang
2025-03-18 3c349ac0531b07239f37b81254f8568ab80e3f6a
fix: use converted token_ids for alignment for sensevoice model with timestamp output (#2429)

* fix: use converted token_ids for alignment

BPE doesn't guarantee that converting ids to subword tokens is reversible, which means converting `tokens` back to ids does not always reproduce `token_int`. An easy fix is to use the converted ids directly for alignment. Since both come from the same text, this shouldn't matter.

* fix: handle empty string

Indexing an empty string (`sentence_text_seg[-1]`) raises an `IndexError`, and there was no check for the empty case here.
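To make the round-trip problem behind the first fix concrete, here is a toy sketch. Everything in it is hypothetical: the four-piece vocabulary and the helpers `toy_decode`, `toy_text2tokens` and `toy_tokens2ids` are stand-ins, not FunASR's tokenizer API, and a greedy longest-match segmenter replaces real BPE merges. It only shows that re-tokenizing decoded text can yield a different id sequence, and even a different length, than the ids the model emitted.

```python
# Toy stand-ins for a subword tokenizer (hypothetical; not FunASR's API).
# A greedy longest-match segmenter replaces real BPE for illustration.
TOY_VOCAB = {"a": 1, "b": 2, "c": 3, "ab": 4}
TOY_ID2PIECE = {i: p for p, i in TOY_VOCAB.items()}


def toy_decode(ids):
    """Concatenate the pieces for a list of ids into plain text."""
    return "".join(TOY_ID2PIECE[i] for i in ids)


def toy_text2tokens(text):
    """Segment text greedily, always taking the longest piece in the vocabulary."""
    pieces, pos = [], 0
    max_len = max(len(p) for p in TOY_VOCAB)
    while pos < len(text):
        for size in range(min(max_len, len(text) - pos), 0, -1):
            piece = text[pos : pos + size]
            if piece in TOY_VOCAB:
                pieces.append(piece)
                pos += size
                break
        else:  # no piece matched at this position
            raise ValueError(f"cannot segment {text[pos:]!r}")
    return pieces


def toy_tokens2ids(tokens):
    """Map each piece back to its id."""
    return [TOY_VOCAB[t] for t in tokens]


# Suppose the model emitted "a", "b", "c" as three separate pieces.
token_int = [1, 2, 3]
text = toy_decode(token_int)          # "abc"
tokens = toy_text2tokens(text)        # ["ab", "c"]: re-tokenizing merges "a" + "b"
token_ids = toy_tokens2ids(tokens)    # [4, 3]

print(len(token_int), len(tokens), len(token_ids))
```

In this toy case `token_int` has three entries but `tokens` has only two, so an alignment computed from `token_int` would produce one more non-blank segment than there are tokens to label, while `token_ids` matches `tokens` one-to-one.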
2 files changed
12 lines changed in total
funasr/models/sense_voice/model.py 10
funasr/utils/timestamp_tools.py 2
funasr/models/sense_voice/model.py
@@ -919,14 +919,20 @@
                 timestamp = []
+                tokens = tokenizer.text2tokens(text)[4:]
+                token_back_to_id = tokenizer.tokens2ids(tokens)
+                token_ids = []
+                for tok_ls in token_back_to_id:
+                    if tok_ls: token_ids.extend(tok_ls)
+                    else: token_ids.append(124)
                 logits_speech = self.ctc.softmax(encoder_out)[i, 4 : encoder_out_lens[i].item(), :]
                 pred = logits_speech.argmax(-1).cpu()
                 logits_speech[pred == self.blank_id, self.blank_id] = 0
                 align = ctc_forced_align(
                     logits_speech.unsqueeze(0).float(),
-                    torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
+                    torch.Tensor(token_ids).unsqueeze(0).long().to(logits_speech.device),
                     (encoder_out_lens[i] - 4).long(),
-                    torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
+                    torch.tensor(len(token_ids)).unsqueeze(0).long().to(logits_speech.device),
                     ignore_id=self.ignore_id,
                 )
                 pred = groupby(align[0, : encoder_out_lens[i]])
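The two small pieces of logic around `ctc_forced_align` can be sketched in plain Python, without torch. The helper names, the blank id of 0, the per-token id lists and the hand-written alignment path are assumptions for illustration, and `ctc_forced_align` itself is not reimplemented; only the fallback id 124 and the flatten-then-`groupby` pattern come from the patch. Because a forced alignment emits each target id once, in order, the non-blank segments correspond one-to-one with `token_ids`.

```python
from itertools import groupby

BLANK_ID = 0        # assumption: blank id used in this toy example
FALLBACK_ID = 124   # placeholder id the patch appends when a token maps to no ids


def flatten_token_ids(token_back_to_id):
    """Flatten per-token id lists, substituting FALLBACK_ID for empty ones.

    Mirrors the loop added in model.py: a non-empty entry contributes all of
    its ids, and an empty entry still contributes exactly one id.
    """
    token_ids = []
    for tok_ls in token_back_to_id:
        if tok_ls:
            token_ids.extend(tok_ls)
        else:
            token_ids.append(FALLBACK_ID)
    return token_ids


def segments_from_alignment(align, blank_id=BLANK_ID):
    """Collapse a frame-level alignment into (token_id, start, end) spans.

    Consecutive frames with the same id form one segment (as groupby does in
    the patch); blank segments are skipped. Frame ranges are half-open.
    """
    spans, start = [], 0
    for tok, frames in groupby(align):
        end = start + len(list(frames))
        if tok != blank_id:
            spans.append((tok, start, end))
        start = end
    return spans


# Hypothetical per-token id lists; the middle token maps to no ids.
token_back_to_id = [[4], [], [3]]
token_ids = flatten_token_ids(token_back_to_id)   # [4, 124, 3]

# Hand-written frame-level path standing in for ctc_forced_align's output:
# every target id appears once, in order, with blanks allowed around it.
align = [0, 4, 4, 0, 124, 0, 3, 3, 0]
spans = segments_from_alignment(align)
print(spans)
```

Keeping one id for an empty entry means the flattened sequence never silently drops a token, so the segment count stays tied to the token list.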
funasr/utils/timestamp_tools.py
@@ -160,7 +160,7 @@
         punc_id = int(punc_id) if punc_id is not None else 1
         sentence_end = timestamp[1] if timestamp is not None else sentence_end
         sentence_text_seg = (
-            sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
+            sentence_text_seg[:-1] if sentence_text_seg and sentence_text_seg[-1] == " " else sentence_text_seg
         )
         if punc_id > 1:
             sentence_text += punc_list[punc_id - 2]
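A minimal sketch of the empty-string fix, with hypothetical helper names wrapping the old and new expressions from `timestamp_tools.py`. It shows that the short-circuiting `and` stops `seg[-1]` from being evaluated on an empty string. The `str.endswith` variant is only an alternative that behaves the same on these inputs; the commit uses the explicit truthiness check.

```python
def strip_trailing_space_old(seg: str) -> str:
    """Unpatched expression: seg[-1] raises IndexError when seg is empty."""
    return seg[:-1] if seg[-1] == " " else seg


def strip_trailing_space(seg: str) -> str:
    """Patched expression: `seg and ...` short-circuits on "" before indexing."""
    return seg[:-1] if seg and seg[-1] == " " else seg


def strip_trailing_space_endswith(seg: str) -> str:
    """Alternative (not used by the commit): str.endswith is safe on ""."""
    return seg[:-1] if seg.endswith(" ") else seg


for s in ["", "hello ", "hello"]:
    print(repr(strip_trailing_space(s)), repr(strip_trailing_space_endswith(s)))

try:
    strip_trailing_space_old("")
except IndexError as exc:
    print("unpatched expression fails:", exc)
```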