From a7814a7bc32aa62ed70631f6478d407fdc0ff488 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Wed, 17 May 2023 17:13:32 +0800
Subject: [PATCH] fix paraformer online last chunk decoding strategy
---
funasr/bin/asr_infer.py | 17 -----------------
funasr/models/encoder/sanm_encoder.py | 11 +----------
funasr/models/predictor/cif.py | 5 +++--
3 files changed, 4 insertions(+), 29 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index f6c5504..03145f8 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -762,23 +762,6 @@
feats_len = speech_lengths
if feats.shape[1] != 0:
- if cache_en["is_final"]:
- if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
- cache_en["last_chunk"] = True
- else:
- # first chunk
- feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
- feats_len = torch.tensor([feats_chunk1.shape[1]])
- results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
-
- # last chunk
- cache_en["last_chunk"] = True
- feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
- feats_len = torch.tensor([feats_chunk2.shape[1]])
- results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-
- return [" ".join(results_chunk1 + results_chunk2)]
-
results = self.infer(feats, feats_len, cache)
return results
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index e071e57..da67586 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -355,18 +355,9 @@
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
- # process last chunk
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- if cache["is_final"]:
- cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
- if not cache["last_chunk"]:
- padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
- overlap_feats = overlap_feats.transpose(1, 2)
- overlap_feats = F.pad(overlap_feats, (0, padding_length))
- overlap_feats = overlap_feats.transpose(1, 2)
- else:
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
return overlap_feats
def forward_chunk(self,
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c59e245..3c363db 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -221,13 +221,14 @@
if cache is not None and "chunk_size" in cache:
alphas[:, :cache["chunk_size"][0]] = 0.0
- alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
+ if "is_final" in cache and not cache["is_final"]:
+ alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
- if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
+ if cache is not None and "is_final" in cache and cache["is_final"]:
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
--
Gitblit v1.9.1