zhifu gao
2023-05-18 97a689d65da434345a641a909f13b78e5690c86b
funasr/models/predictor/cif.py
@@ -221,13 +221,14 @@
        if cache is not None and "chunk_size" in cache:
            alphas[:, :cache["chunk_size"][0]] = 0.0
            alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
            if "is_final" in cache and not cache["is_final"]:
                alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
        if cache is not None and "is_final" in cache and cache["is_final"]:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))