From 6be782d9fde7c6a490fbe4b3f22de3bfc7a69406 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 03 四月 2023 19:57:26 +0800
Subject: [PATCH] fix decoder cache
---
funasr/models/decoder/sanm_decoder.py | 14 +++++++-------
1 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index a36d95e..463918a 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@
if self.self_attn:
if self.normalize_before:
tgt = self.norm2(tgt)
- x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x, _ = self.self_attn(tgt, tgt_mask)
x = residual + self.dropout(x)
if self.src_attn is not None:
@@ -399,7 +399,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@@ -409,13 +409,13 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
@@ -1076,7 +1076,7 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@@ -1086,14 +1086,14 @@
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
- x, tgt_mask, memory, memory_mask, c_ret = decoder(
+ x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
- x, tgt_mask, memory, memory_mask, _ = decoder(
+ x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
--
Gitblit v1.9.1