update func cif_wo_hidden
| | |
| | |
|
| | | fire_place = integrate >= threshold
|
| | | integrate = torch.where(fire_place,
|
| | | integrate - torch.ones([batch_size], device=alphas.device),
|
| | | integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
| | | integrate)
|
| | |
|
| | | fires = torch.stack(list_fires, 1)
|
| | |
| | |
|
| | | fire_place = integrate >= threshold
|
| | | integrate = torch.where(fire_place,
|
| | | integrate - torch.ones([batch_size], device=alphas.device),
|
| | | integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
| | | integrate)
|
| | |
|
| | | fires = torch.stack(list_fires, 1)
|
| | |
| | | list_fires.append(integrate) |
| | | fire_place = integrate >= threshold |
| | | integrate = torch.where(fire_place, |
| | | integrate - torch.ones([batch_size], device=alphas.device), |
| | | integrate - torch.ones([batch_size], device=alphas.device)*threshold, |
| | | integrate) |
| | | fires = torch.stack(list_fires, 1) |
| | | return fires |