zhaomingwork
2023-08-21 ca0ada5915b4c71fc4a30db689e0e3646181aebf
funasr/utils/timestamp_tools.py
@@ -19,7 +19,7 @@
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=alphas.device),
                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                integrate)
    fires = torch.stack(list_fires, 1)
    return fires