From b18f7d121f2f17df8bf2d0c2bbb223bc5ddbcc0f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 五月 2023 16:11:22 +0800
Subject: [PATCH] docs

---
 funasr/models/encoder/sanm_encoder.py |   19 +++++++------------
 1 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 7d84ad5..da67586 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -27,9 +27,10 @@
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -354,18 +355,9 @@
     def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
         if len(cache) == 0:
             return feats
-        # process last chunk
         cache["feats"] = to_device(cache["feats"], device=feats.device)
         overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        if cache["is_final"]:
-            cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
-            if not cache["last_chunk"]:
-               padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
-               overlap_feats = overlap_feats.transpose(1, 2)
-               overlap_feats = F.pad(overlap_feats, (0, padding_length))
-               overlap_feats = overlap_feats.transpose(1, 2)
-        else:
-            cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
         return overlap_feats
 
     def forward_chunk(self,
@@ -379,7 +371,10 @@
             xs_pad = xs_pad
         else:
             xs_pad = self.embed(xs_pad, cache)
-        xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        if cache["tail_chunk"]:
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
         encoder_outs = self.encoders0(xs_pad, None, None, None, None)
         xs_pad, masks = encoder_outs[0], encoder_outs[1]
         intermediate_outs = []

--
Gitblit v1.9.1