zhong zhuang
2024-06-14 145022c438c8158e6bffdc5ff66fdd28468a0a18
Update cif_predictor.py (#1811)

* Update cif_predictor.py

* modify cif_v1_export

In extreme cases, the max_label_len computed from batch_len misaligns with token_num.

* Update cif_predictor.py

torch.cumsum suffers precision degradation in float32, producing wrong results in extreme cases; compute it in float64 instead.
1个文件已修改
9 ■■■■■ 已修改文件
funasr/models/paraformer/cif_predictor.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/cif_predictor.py
@@ -504,7 +504,7 @@
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -539,7 +539,8 @@
    frames = frames - shift_frames + shift_remain_frames - remain_frames
    max_label_len = batch_len.max()
    max_label_len = alphas.sum(dim=-1)
    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
    frame_fires = torch.zeros(
        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
@@ -670,7 +671,7 @@
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -718,7 +719,7 @@
    frames = frames - shift_frames + shift_remain_frames - remain_frames
    max_label_len = batch_len.max()
    max_label_len = torch.round(alphas.sum(-1)).int().max() # torch.round to calculate the max length
    frame_fires = torch.zeros(
        batch_size, max_label_len, hidden_size, dtype=dtype, device=device