zhifu gao
2023-02-07 2e336050db6b2c634f9ae5f7c4ef6710fd641822
Merge pull request #72 from alibaba-damo-academy/dev_lzr

Fix bug in the predictor's `tail_process_fn`
1 file changed
5 ■■■■■ changed file
funasr/models/predictor/cif.py 5 ●●●●● patch | view | raw | blame | history
funasr/models/predictor/cif.py
@@ -208,7 +208,8 @@
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, tail_threshold], dim=1)
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
            tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -654,4 +655,4 @@
        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
        return predictor_alignments.detach(), predictor_alignments_length.detach()