From 3762d21300e1f3fa3e0cb1e67545227e6dcec3de Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 13 三月 2023 22:02:54 +0800
Subject: [PATCH] add streaming paraformer code

---
 funasr/models/encoder/sanm_encoder.py |   42 ++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 42 insertions(+), 0 deletions(-)

diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 0751a10..57890ef 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -347,6 +347,48 @@
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
 
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ctc: CTC = None,
+                      ):
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            xs_pad = self.embed.forward_chunk(xs_pad, cache)
+
+        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, None, None, None, None)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), None, None
+        return xs_pad, ilens, None
+
     def gen_tf2torch_map_dict(self):
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
         tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf

--
Gitblit v1.9.1