| | |
| | |
|
| | | if cache is not None and "chunk_size" in cache:
|
| | | alphas[:, :cache["chunk_size"][0]] = 0.0
|
| | | alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
| | | if "is_final" in cache and not cache["is_final"]:
|
| | | alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
| | | if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
| | | cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
| | | cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
| | | hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
| | | alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
| | | if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
|
| | | if cache is not None and "is_final" in cache and cache["is_final"]:
|
| | | tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
| | | tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
| | | tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|