From 2e769fb36ce88dabfa984e8b81e8cb1c90799c95 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 07 四月 2023 15:54:09 +0800
Subject: [PATCH] Merge branch 'main' into dev_cmz2
---
funasr/models/decoder/sanm_decoder.py | 13 ++++++-------
1 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 3bfcffc..463918a 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -104,7 +104,6 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -400,7 +399,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@@ -410,13 +409,13 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
@@ -1077,7 +1076,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@@ -1087,14 +1086,14 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
--
Gitblit v1.9.1