雾聪
2023-11-13 c0bfc7c6fe6e1e7f1ead8174a6ea7341706b46d5
funasr/models/predictor/cif.py
@@ -499,7 +499,7 @@
        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                integrate)
    fires = torch.stack(list_fires, 1)