From 3762d21300e1f3fa3e0cb1e67545227e6dcec3de Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 13 三月 2023 22:02:54 +0800
Subject: [PATCH] add streaming paraformer code

---
 funasr/models/decoder/sanm_decoder.py |   59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 59 insertions(+), 0 deletions(-)

diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index ab03f0b..0117430 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -947,6 +947,65 @@
         )
         return logp.squeeze(0), state
 
+    def forward_chunk(
+        self,
+        memory: torch.Tensor,
+        tgt: torch.Tensor,
+        cache: dict = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        x = tgt
+        if cache["decode_fsmn"] is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            new_cache = [None] * cache_layer_num
+        else:
+            new_cache = cache["decode_fsmn"]
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, None, memory, None, cache=new_cache[i]
+            )
+            new_cache[i] = c_ret
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                    x, None, memory, None, cache=new_cache[j]
+                )
+                new_cache[j] = c_ret
+
+        for decoder in self.decoders3:
+
+            x, tgt_mask, memory, memory_mask, _ = decoder(
+                x, None, memory, None, cache=None
+            )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+        cache["decode_fsmn"] = new_cache
+        return x
+
     def forward_one_step(
         self,
         tgt: torch.Tensor,

--
Gitblit v1.9.1