From 3762d21300e1f3fa3e0cb1e67545227e6dcec3de Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: Mon, 13 Mar 2023 22:02:54 +0800
Subject: [PATCH] add streaming paraformer code

---
 funasr/modules/attention.py |   10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6277005..31d5a87 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -347,15 +347,17 @@
             mask = torch.reshape(mask, (b, -1, 1))
             if mask_shfit_chunk is not None:
                 mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
 
-        inputs = inputs * mask
         x = inputs.transpose(1, 2)
         x = self.pad_fn(x)
         x = self.fsmn_block(x)
         x = x.transpose(1, 2)
         x += inputs
         x = self.dropout(x)
-        return x * mask
+        if mask is not None:
+            x = x * mask
+        return x
 
     def forward_qkv(self, x):
         """Transform query, key and value.
@@ -505,7 +507,7 @@
             # print("in fsmn, cache is None, x", x.size())
 
             x = self.pad_fn(x)
-            if not self.training and t <= 1:
+            if not self.training:
                 cache = x
         else:
             # print("in fsmn, cache is not None, x", x.size())
@@ -513,7 +515,7 @@
             # if t < self.kernel_size:
             #     x = self.pad_fn(x)
             x = torch.cat((cache[:, :, 1:], x), dim=2)
-            x = x[:, :, -self.kernel_size:]
+            x = x[:, :, -(self.kernel_size+t-1):]
             # print("in fsmn, cache is not None, x_cat", x.size())
             cache = x
         x = self.fsmn_block(x)

--
Gitblit v1.9.1