From 97a689d65da434345a641a909f13b78e5690c86b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 18 May 2023 19:35:08 +0800
Subject: [PATCH] Merge pull request #526 from alibaba-damo-academy/dev_infer
---
funasr/models/encoder/sanm_encoder.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a68011..da67586 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -27,9 +27,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -354,18 +355,9 @@
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
- # process last chunk
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- if cache["is_final"]:
- cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
- if not cache["last_chunk"]:
- padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
- overlap_feats = overlap_feats.transpose(1, 2)
- overlap_feats = F.pad(overlap_feats, (0, padding_length))
- overlap_feats = overlap_feats.transpose(1, 2)
- else:
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
return overlap_feats
def forward_chunk(self,
--
Gitblit v1.9.1