From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Add ONNX export classes to paraformer decoder
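
Reformat funasr/models/paraformer/decoder.py and add export wrapper classes
for ONNX export: DecoderLayerSANMExport, ParaformerSANMDecoderExport,
ParaformerSANMDecoderOnlineExport, and ParaformerDecoderSANExport, each
registered in the "decoder_classes" table. Also reorder the chunk_mask,
return_hidden, and return_both keyword arguments of
ParaformerSANMDecoder.forward.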
---
funasr/models/paraformer/decoder.py | 627 +++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 561 insertions(+), 66 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ad321e4..7edd91a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -17,7 +17,10 @@
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+ MultiHeadedAttentionSANMDecoder,
+ MultiHeadedAttentionCrossAtt,
+)
class DecoderLayerSANM(torch.nn.Module):
@@ -69,7 +72,7 @@
if self.concat_after:
self.concat_linear1 = torch.nn.Linear(size + size, size)
self.concat_linear2 = torch.nn.Linear(size + size, size)
- self.reserve_attn=False
+ self.reserve_attn = False
self.attn_mat = []
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -116,7 +119,7 @@
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
-
+
def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
residual = tgt
tgt = self.norm1(tgt)
@@ -173,10 +176,11 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ def forward_chunk(
+ self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+ ):
"""Compute decoded features.
Args:
@@ -224,6 +228,7 @@
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
+
def __init__(
self,
vocab_size: int,
@@ -303,7 +308,13 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+ attention_heads,
+ attention_dim,
+ src_attention_dropout_rate,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -351,9 +362,9 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
- return_hidden: bool = False,
- return_both: bool= False,
chunk_mask: torch.Tensor = None,
+ return_hidden: bool = False,
+ return_both: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -374,7 +385,7 @@
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
@@ -383,16 +394,10 @@
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
hidden = self.after_norm(x)
@@ -407,12 +412,12 @@
def score(self, ys, state, x):
"""Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
+ ys_mask = myutils.sequence_mask(
+ torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+ )[:, :, None]
+ logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
-
+
def forward_asf2(
self,
hs_pad: torch.Tensor,
@@ -430,7 +435,7 @@
tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
return attn_mat
-
+
def forward_asf6(
self,
hs_pad: torch.Tensor,
@@ -494,22 +499,24 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
- chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+ x,
+ memory,
+ fsmn_cache=fsmn_cache[i],
+ opt_cache=opt_cache[i],
+ chunk_size=cache["chunk_size"],
+ look_back=cache["decoder_chunk_look_back"],
)
if self.num_blocks - self.att_layer_num > 1:
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
decoder = self.decoders2[i]
- x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
+ x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
x, memory, fsmn_cache=fsmn_cache[j]
)
for decoder in self.decoders3:
- x, memory, _, _ = decoder.forward_chunk(
- x, memory
- )
+ x, memory, _, _ = decoder.forward_chunk(x, memory)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -582,6 +589,395 @@
return y, new_cache
+class DecoderLayerSANMExport(torch.nn.Module):
+
+ def __init__(self, model):
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.src_attn = model.src_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2 if hasattr(model, "norm2") else None
+ self.norm3 = model.norm3 if hasattr(model, "norm3") else None
+ self.size = model.size
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ if self.src_attn is not None:
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.src_attn(x, memory, memory_mask)
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+ def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ residual = x
+ x = self.norm3(x)
+ x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+ return attn_mat
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
+class ParaformerSANMDecoderExport(torch.nn.Module):
+ def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ return_hidden: bool = False,
+ return_both: bool = False,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
+ if self.model.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
+ hidden = self.after_norm(x)
+ # x = self.output_layer(x)
+
+ if self.output_layer is not None and return_hidden is False:
+ x = self.output_layer(hidden)
+ return x, ys_in_lens
+ if return_both:
+ x = self.output_layer(hidden)
+ return x, hidden, ys_in_lens
+ return hidden, ys_in_lens
+
+ def forward_asf2(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ def forward_asf6(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ """
+ def get_dummy_inputs(self, enc_size):
+ tgt = torch.LongTensor([0]).unsqueeze(0)
+ memory = torch.randn(1, 100, enc_size)
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ cache = [
+ torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+ for _ in range(cache_num)
+ ]
+ return (tgt, memory, pre_acoustic_embeds, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ + ['cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['y'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'tgt': {
+ 0: 'tgt_batch',
+ 1: 'tgt_length'
+ },
+ 'memory': {
+ 0: 'memory_batch',
+ 1: 'memory_length'
+ },
+ 'pre_acoustic_embeds': {
+ 0: 'acoustic_embeds_batch',
+ 1: 'acoustic_embeds_length',
+ }
+ }
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ ret.update({
+ 'cache_%d' % d: {
+ 0: 'cache_%d_batch' % d,
+ 2: 'cache_%d_length' % d
+ }
+ for d in range(cache_num)
+ })
+ return ret
+ """
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
+class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
+ def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ *args,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ out_caches = list()
+ for i, decoder in enumerate(self.model.decoders):
+ in_cache = args[i]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ if self.model.decoders2 is not None:
+ for i, decoder in enumerate(self.model.decoders2):
+ in_cache = args[i + len(self.model.decoders)]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, out_caches
+
+ def get_dummy_inputs(self, enc_size):
+ enc = torch.randn(2, 100, enc_size).type(torch.float32)
+ enc_len = torch.tensor([30, 100], dtype=torch.int32)
+ acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
+ acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ cache = [
+ torch.zeros(
+ (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+ dtype=torch.float32,
+ )
+ for _ in range(cache_num)
+ ]
+ return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
+ "in_cache_%d" % i for i in range(cache_num)
+ ]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ "enc": {0: "batch_size", 1: "enc_length"},
+ "acoustic_embeds": {0: "batch_size", 1: "token_length"},
+ "enc_len": {
+ 0: "batch_size",
+ },
+ "acoustic_embeds_len": {
+ 0: "batch_size",
+ },
+ }
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ ret.update(
+ {
+ "in_cache_%d"
+ % d: {
+ 0: "batch_size",
+ }
+ for d in range(cache_num)
+ }
+ )
+ ret.update(
+ {
+ "out_cache_%d"
+ % d: {
+ 0: "batch_size",
+ }
+ for d in range(cache_num)
+ }
+ )
+ return ret
+
+
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
"""
@@ -589,23 +985,24 @@
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
+
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- embeds_id: int = -1,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ embeds_id: int = -1,
):
super().__init__(
vocab_size=vocab_size,
@@ -623,12 +1020,8 @@
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -639,11 +1032,11 @@
self.attention_dim = attention_dim
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -666,23 +1059,17 @@
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
memory = hs_pad
- memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
- memory.device
- )
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
- memory_mask = torch.nn.functional.pad(
- memory_mask, (0, padlen), "constant", False
- )
+ memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
# x = self.embed(tgt)
x = tgt
embeds_outputs = None
for layer_id, decoder in enumerate(self.decoders):
- x, tgt_mask, memory, memory_mask = decoder(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
if layer_id == self.embeds_id:
embeds_outputs = x
if self.normalize_before:
@@ -696,3 +1083,111 @@
else:
return x, olens
+
+@tables.register("decoder_classes", "ParaformerDecoderSANExport")
+class ParaformerDecoderSANExport(torch.nn.Module):
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ model_name="decoder",
+ onnx: bool = True,
+ ):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.transformer.decoder import DecoderLayerExport
+ from funasr.models.transformer.attention import MultiHeadedAttentionExport
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.src_attn, MultiHeadedAttention):
+ d.src_attn = MultiHeadedAttentionExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, ys_in_lens
+
+ def get_dummy_inputs(self, enc_size):
+ tgt = torch.LongTensor([0]).unsqueeze(0)
+ memory = torch.randn(1, 100, enc_size)
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ cache = [
+ torch.zeros(
+ (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
+ )
+ for _ in range(cache_num)
+ ]
+ return (tgt, memory, pre_acoustic_embeds, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ "tgt": {0: "tgt_batch", 1: "tgt_length"},
+ "memory": {0: "memory_batch", 1: "memory_length"},
+ "pre_acoustic_embeds": {
+ 0: "acoustic_embeds_batch",
+ 1: "acoustic_embeds_length",
+ },
+ }
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ ret.update(
+ {
+ "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+ for d in range(cache_num)
+ }
+ )
+ return ret
--
Gitblit v1.9.1