From a7814a7bc32aa62ed70631f6478d407fdc0ff488 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期三, 17 五月 2023 17:13:32 +0800
Subject: [PATCH] fix paraformer online last chunk decoding strategy

---
 funasr/models/predictor/cif.py |    5 +++--
 1 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c59e245..3c363db 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -221,13 +221,14 @@
 
         if cache is not None and "chunk_size" in cache:
             alphas[:, :cache["chunk_size"][0]] = 0.0
-            alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
+            if "is_final" in cache and not cache["is_final"]:
+                alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
             cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
             hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
             alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
-        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
+        if cache is not None and "is_final" in cache and cache["is_final"]:
             tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
             tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
             tail_alphas = torch.tile(tail_alphas, (batch_size, 1))

--
Gitblit v1.9.1