维石
2024-06-21 362eed972c885bd3526b75df6e1527925abe06c2
rollback cif_v1 for training bug
1 file changed, 2 insertions(+), 2 deletions(-)
funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
                    hidden, alphas, token_num, mask=mask
                )
-            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@
                        hidden, alphas, token_num, mask=None
                    )
-            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
+            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
                acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
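Both hunks swap `cif_v1` back to `cif`, which maps frame-level `hidden` states and per-frame weights `alphas` to token-level `acoustic_embeds` by continuous integrate-and-fire (CIF). For orientation only, here is a minimal, unbatched sketch of that integrate-and-fire idea. It is not FunASR's `cif` or `cif_v1` and says nothing about the training bug; the name `cif_sketch`, the plain-list shapes (one utterance, no batch dimension), and the assumption that every `alpha` is below `threshold` (so at most one token fires per frame) are all illustrative.

```python
from typing import List, Sequence, Tuple


def cif_sketch(
    hidden: Sequence[Sequence[float]],
    alphas: Sequence[float],
    threshold: float = 1.0,
) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical, unbatched continuous integrate-and-fire.

    Assumes each alpha < threshold, so at most one token fires per frame.
    Returns (token_embeds, peaks): one integrated vector per fired token,
    and the accumulated weight observed at each frame.
    """
    dim = len(hidden[0]) if hidden else 0
    integrate = 0.0          # weight accumulated toward the current token
    frame = [0.0] * dim      # weighted sum of hidden states for the current token
    token_embeds: List[List[float]] = []
    peaks: List[float] = []
    for h, alpha in zip(hidden, alphas):
        integrate += alpha
        peaks.append(integrate)
        if integrate >= threshold:
            # Part of this frame's weight completes the current token...
            cur = alpha - (integrate - threshold)
            token_embeds.append([f + cur * x for f, x in zip(frame, h)])
            # ...and the leftover weight starts the next token.
            remainder = alpha - cur
            frame = [remainder * x for x in h]
            integrate -= threshold
        else:
            frame = [f + alpha * x for f, x in zip(frame, h)]
    return token_embeds, peaks


embeds, peaks = cif_sketch([[1.0], [2.0], [3.0], [4.0]], [0.5, 0.75, 0.5, 0.25])
print(embeds, peaks)
```

With these inputs two tokens fire (at frames 1 and 3), so `embeds` is `[[1.5], [3.0]]` and `peaks` is `[0.5, 1.25, 0.75, 1.0]`. The real predictor additionally handles batches and, per the context lines above, trims `acoustic_embeds` to the largest `token_num` when `tail_threshold > 0.0` and no `target_length` is given.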