From 6be782d9fde7c6a490fbe4b3f22de3bfc7a69406 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 03 四月 2023 19:57:26 +0800
Subject: [PATCH] fix decoder cache

---
 funasr/models/decoder/sanm_decoder.py |   14 +++++++-------
 1 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index a36d95e..463918a 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@
         if self.self_attn:
             if self.normalize_before:
                 tgt = self.norm2(tgt)
-            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x, _ = self.self_attn(tgt, tgt_mask)
             x = residual + self.dropout(x)
 
         if self.src_attn is not None:
@@ -399,7 +399,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
@@ -409,13 +409,13 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 
@@ -1076,7 +1076,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=c
             )
             new_cache.append(c_ret)
@@ -1086,14 +1086,14 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, None, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
 
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 

--
Gitblit v1.9.1