From cbbb3007436bf7cefa45af9a5aad60c696da9732 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期一, 03 四月 2023 20:01:28 +0800
Subject: [PATCH] Merge pull request #324 from alibaba-damo-academy/dev_lhn2

---
 funasr/models/decoder/sanm_decoder.py |   14 +++++++-------
 1 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index a36d95e..463918a 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@
         if self.self_attn:
             if self.normalize_before:
                 tgt = self.norm2(tgt)
-            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x, _ = self.self_attn(tgt, tgt_mask)
             x = residual + self.dropout(x)
 
         if self.src_attn is not None:
@@ -399,7 +399,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
@@ -409,13 +409,13 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 
@@ -1076,7 +1076,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=c
             )
             new_cache.append(c_ret)
@@ -1086,14 +1086,14 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, None, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
 
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 

--
Gitblit v1.9.1