From 836d57bb6c08c76dada384d93ca0ee3cc5374f48 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 20 十二月 2023 17:03:23 +0800
Subject: [PATCH] update seaco paraformer
---
funasr/models/paraformer/decoder.py | 58 +++++++++++++++++++++++++++++++++++++---------------------
1 files changed, 37 insertions(+), 21 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 3fe9d19..f59ce4d 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -68,6 +68,8 @@
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
+ self.reserve_attn=False
+ self.attn_mat = []
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
@@ -104,8 +106,13 @@
residual = x
if self.normalize_before:
x = self.norm3(x)
-
- x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+ if self.reserve_attn:
+ x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+ self.attn_mat.append(attn_mat)
+ else:
+ x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+ x = residual + self.dropout(x_src_attn)
+ # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
@@ -213,6 +220,7 @@
src_attention_dropout_rate: float = 0.0,
input_layer: str = "embed",
use_output_layer: bool = True,
+ wo_input_layer: bool = False,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
@@ -239,22 +247,24 @@
)
attention_dim = encoder_output_size
-
- if input_layer == "embed":
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(vocab_size, attention_dim),
- # pos_enc_class(attention_dim, positional_dropout_rate),
- )
- elif input_layer == "linear":
- self.embed = torch.nn.Sequential(
- torch.nn.Linear(vocab_size, attention_dim),
- torch.nn.LayerNorm(attention_dim),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- pos_enc_class(attention_dim, positional_dropout_rate),
- )
+ if wo_input_layer:
+ self.embed = None
else:
- raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ # pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
@@ -324,6 +334,8 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
+ return_hidden: bool = False,
+ return_both: bool= False,
chunk_mask: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -365,12 +377,16 @@
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
- x = self.after_norm(x)
- if self.output_layer is not None:
- x = self.output_layer(x)
+ hidden = self.after_norm(x)
olens = tgt_mask.sum(1)
- return x, olens
+ if self.output_layer is not None and return_hidden is False:
+ x = self.output_layer(hidden)
+ return x, olens
+ if return_both:
+ x = self.output_layer(hidden)
+ return x, hidden, olens
+ return hidden, olens
def score(self, ys, state, x):
"""Score."""
--
Gitblit v1.9.1