From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/paraformer/decoder.py | 409 ++++++++++++++++++++++++++++------------------------------
 1 file changed, 197 insertions(+), 212 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index f08e97b..7edd91a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -17,7 +17,10 @@
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+ MultiHeadedAttentionSANMDecoder,
+ MultiHeadedAttentionCrossAtt,
+)
class DecoderLayerSANM(torch.nn.Module):
@@ -69,7 +72,7 @@
if self.concat_after:
self.concat_linear1 = torch.nn.Linear(size + size, size)
self.concat_linear2 = torch.nn.Linear(size + size, size)
- self.reserve_attn=False
+ self.reserve_attn = False
self.attn_mat = []
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -116,7 +119,7 @@
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
-
+
def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
residual = tgt
tgt = self.norm1(tgt)
@@ -173,10 +176,11 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ def forward_chunk(
+ self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+ ):
"""Compute decoded features.
Args:
@@ -224,6 +228,7 @@
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
+
def __init__(
self,
vocab_size: int,
@@ -303,7 +308,13 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+ attention_heads,
+ attention_dim,
+ src_attention_dropout_rate,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -351,9 +362,9 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
- return_hidden: bool = False,
- return_both: bool= False,
chunk_mask: torch.Tensor = None,
+ return_hidden: bool = False,
+ return_both: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -374,7 +385,7 @@
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
@@ -383,16 +394,10 @@
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
hidden = self.after_norm(x)
@@ -407,12 +412,12 @@
def score(self, ys, state, x):
"""Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
+ ys_mask = myutils.sequence_mask(
+ torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+ )[:, :, None]
+ logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
-
+
def forward_asf2(
self,
hs_pad: torch.Tensor,
@@ -430,7 +435,7 @@
tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
return attn_mat
-
+
def forward_asf6(
self,
hs_pad: torch.Tensor,
@@ -494,22 +499,24 @@
for i in range(self.att_layer_num):
decoder = self.decoders[i]
x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
- x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
- chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+ x,
+ memory,
+ fsmn_cache=fsmn_cache[i],
+ opt_cache=opt_cache[i],
+ chunk_size=cache["chunk_size"],
+ look_back=cache["decoder_chunk_look_back"],
)
if self.num_blocks - self.att_layer_num > 1:
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
decoder = self.decoders2[i]
- x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
+ x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
x, memory, fsmn_cache=fsmn_cache[j]
)
for decoder in self.decoders3:
- x, memory, _, _ = decoder.forward_chunk(
- x, memory
- )
+ x, memory, _, _ = decoder.forward_chunk(x, memory)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -581,21 +588,18 @@
return y, new_cache
+
class DecoderLayerSANMExport(torch.nn.Module):
- def __init__(
- self,
- model
- ):
+ def __init__(self, model):
super().__init__()
self.self_attn = model.self_attn
self.src_attn = model.src_attn
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
- self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
- self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
+ self.norm2 = model.norm2 if hasattr(model, "norm2") else None
+ self.norm3 = model.norm3 if hasattr(model, "norm3") else None
self.size = model.size
-
def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -614,9 +618,8 @@
x = self.norm3(x)
x = residual + self.src_attn(x, memory, memory_mask)
-
return x, tgt_mask, memory, memory_mask, cache
-
+
def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
residual = tgt
tgt = self.norm1(tgt)
@@ -636,45 +639,39 @@
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True,
- **kwargs
- ):
+ def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
from funasr.utils.torch_function import sequence_mask
-
+
self.model = model
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
+
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
-
-
+
for i, d in enumerate(self.model.decoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
self.model.decoders[i] = DecoderLayerSANMExport(d)
-
+
if self.model.decoders2 is not None:
for i, d in enumerate(self.model.decoders2):
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
self.model.decoders2[i] = DecoderLayerSANMExport(d)
-
+
for i, d in enumerate(self.model.decoders3):
self.model.decoders3[i] = DecoderLayerSANMExport(d)
-
+
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
-
+
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
@@ -682,9 +679,9 @@
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
+
return mask_3d_btd, mask_4d_bhlt
-
+
def forward(
self,
hs_pad: torch.Tensor,
@@ -694,31 +691,27 @@
return_hidden: bool = False,
return_both: bool = False,
):
-
+
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
+
x = tgt
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
if self.model.decoders2 is not None:
x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
x, tgt_mask, memory, memory_mask
)
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
hidden = self.after_norm(x)
# x = self.output_layer(x)
-
+
if self.output_layer is not None and return_hidden is False:
x = self.output_layer(hidden)
return x, ys_in_lens
@@ -726,7 +719,7 @@
x = self.output_layer(hidden)
return x, hidden, ys_in_lens
return hidden, ys_in_lens
-
+
def forward_asf2(
self,
hs_pad: torch.Tensor,
@@ -742,10 +735,12 @@
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
_, memory_mask = self.prepare_mask(memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+ tgt, tgt_mask, memory, memory_mask
+ )
attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
return attn_mat
-
+
def forward_asf6(
self,
hs_pad: torch.Tensor,
@@ -761,15 +756,25 @@
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
_, memory_mask = self.prepare_mask(memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
- tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
+ tgt, tgt_mask, memory, memory_mask
+ )
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
+ tgt, tgt_mask, memory, memory_mask
+ )
attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
return attn_mat
-
- '''
+
+ """
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
@@ -818,14 +823,12 @@
for d in range(cache_num)
})
return ret
- '''
-
+ """
+
+
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True, **kwargs):
+ def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
@@ -833,7 +836,7 @@
from funasr.utils.torch_function import sequence_mask
self.model = model
-
+
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
@@ -854,11 +857,11 @@
for i, d in enumerate(self.model.decoders3):
self.model.decoders3[i] = DecoderLayerSANMExport(d)
-
+
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
-
+
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
@@ -866,9 +869,9 @@
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
+
return mask_3d_btd, mask_4d_bhlt
-
+
def forward(
self,
hs_pad: torch.Tensor,
@@ -877,17 +880,17 @@
ys_in_lens: torch.Tensor,
*args,
):
-
+
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
+
x = tgt
out_caches = list()
for i, decoder in enumerate(self.model.decoders):
@@ -903,76 +906,75 @@
x, tgt_mask, memory, memory_mask, cache=in_cache
)
out_caches.append(out_cache)
- x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
x = self.after_norm(x)
x = self.output_layer(x)
-
+
return x, out_caches
-
+
def get_dummy_inputs(self, enc_size):
enc = torch.randn(2, 100, enc_size).type(torch.float32)
enc_len = torch.tensor([30, 100], dtype=torch.int32)
acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
cache = [
- torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
- dtype=torch.float32)
+ torch.zeros(
+ (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+ dtype=torch.float32,
+ )
for _ in range(cache_num)
]
return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
-
+
def get_input_names(self):
cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
- return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
- + ['in_cache_%d' % i for i in range(cache_num)]
-
+ return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
+ "in_cache_%d" % i for i in range(cache_num)
+ ]
+
def get_output_names(self):
cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
- return ['logits', 'sample_ids'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
+ return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
+
def get_dynamic_axes(self):
ret = {
- 'enc': {
- 0: 'batch_size',
- 1: 'enc_length'
+ "enc": {0: "batch_size", 1: "enc_length"},
+ "acoustic_embeds": {0: "batch_size", 1: "token_length"},
+ "enc_len": {
+ 0: "batch_size",
},
- 'acoustic_embeds': {
- 0: 'batch_size',
- 1: 'token_length'
+ "acoustic_embeds_len": {
+ 0: "batch_size",
},
- 'enc_len': {
- 0: 'batch_size',
- },
- 'acoustic_embeds_len': {
- 0: 'batch_size',
- },
-
}
cache_num = len(self.model.decoders)
- if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
cache_num += len(self.model.decoders2)
- ret.update({
- 'in_cache_%d' % d: {
- 0: 'batch_size',
+ ret.update(
+ {
+ "in_cache_%d"
+ % d: {
+ 0: "batch_size",
+ }
+ for d in range(cache_num)
}
- for d in range(cache_num)
- })
- ret.update({
- 'out_cache_%d' % d: {
- 0: 'batch_size',
+ )
+ ret.update(
+ {
+ "out_cache_%d"
+ % d: {
+ 0: "batch_size",
+ }
+ for d in range(cache_num)
}
- for d in range(cache_num)
- })
+ )
return ret
@@ -983,23 +985,24 @@
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
+
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- embeds_id: int = -1,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ embeds_id: int = -1,
):
super().__init__(
vocab_size=vocab_size,
@@ -1017,12 +1020,8 @@
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -1033,11 +1032,11 @@
self.attention_dim = attention_dim
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -1060,23 +1059,17 @@
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
memory = hs_pad
- memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
- memory.device
- )
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
- memory_mask = torch.nn.functional.pad(
- memory_mask, (0, padlen), "constant", False
- )
+ memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
# x = self.embed(tgt)
x = tgt
embeds_outputs = None
for layer_id, decoder in enumerate(self.decoders):
- x, tgt_mask, memory, memory_mask = decoder(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
if layer_id == self.embeds_id:
embeds_outputs = x
if self.normalize_before:
@@ -1090,16 +1083,19 @@
else:
return x, olens
+
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
- def __init__(self, model,
- max_seq_len=512,
- model_name='decoder',
- onnx: bool = True, ):
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ model_name="decoder",
+ onnx: bool = True,
+ ):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
-
from funasr.utils.torch_function import sequence_mask
@@ -1107,19 +1103,18 @@
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
from funasr.models.transformer.decoder import DecoderLayerExport
from funasr.models.transformer.attention import MultiHeadedAttentionExport
-
+
for i, d in enumerate(self.model.decoders):
if isinstance(d.src_attn, MultiHeadedAttention):
d.src_attn = MultiHeadedAttentionExport(d.src_attn)
self.model.decoders[i] = DecoderLayerExport(d)
-
+
self.output_layer = model.output_layer
self.after_norm = model.after_norm
self.model_name = model_name
-
+
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
@@ -1127,9 +1122,9 @@
elif len(mask.shape) == 3:
mask_4d_bhlt = 1 - mask[:, None, :]
mask_4d_bhlt = mask_4d_bhlt * -10000.0
-
+
return mask_3d_btd, mask_4d_bhlt
-
+
def forward(
self,
hs_pad: torch.Tensor,
@@ -1137,72 +1132,62 @@
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
):
-
+
tgt = ys_in_pad
tgt_mask = self.make_pad_mask(ys_in_lens)
tgt_mask, _ = self.prepare_mask(tgt_mask)
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = self.make_pad_mask(hlens)
_, memory_mask = self.prepare_mask(memory_mask)
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
+
x = tgt
- x, tgt_mask, memory, memory_mask = self.model.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
x = self.after_norm(x)
x = self.output_layer(x)
-
+
return x, ys_in_lens
-
+
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
pre_acoustic_embeds = torch.randn(1, 1, enc_size)
cache_num = len(self.model.decoders) + len(self.model.decoders2)
cache = [
- torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+ torch.zeros(
+ (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
+ )
for _ in range(cache_num)
]
return (tgt, memory, pre_acoustic_embeds, cache)
-
+
def is_optimizable(self):
return True
-
+
def get_input_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['tgt', 'memory', 'pre_acoustic_embeds'] \
- + ['cache_%d' % i for i in range(cache_num)]
-
+ return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
+
def get_output_names(self):
cache_num = len(self.model.decoders) + len(self.model.decoders2)
- return ['y'] \
- + ['out_cache_%d' % i for i in range(cache_num)]
-
+ return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
+
def get_dynamic_axes(self):
ret = {
- 'tgt': {
- 0: 'tgt_batch',
- 1: 'tgt_length'
+ "tgt": {0: "tgt_batch", 1: "tgt_length"},
+ "memory": {0: "memory_batch", 1: "memory_length"},
+ "pre_acoustic_embeds": {
+ 0: "acoustic_embeds_batch",
+ 1: "acoustic_embeds_length",
},
- 'memory': {
- 0: 'memory_batch',
- 1: 'memory_length'
- },
- 'pre_acoustic_embeds': {
- 0: 'acoustic_embeds_batch',
- 1: 'acoustic_embeds_length',
- }
}
cache_num = len(self.model.decoders) + len(self.model.decoders2)
- ret.update({
- 'cache_%d' % d: {
- 0: 'cache_%d_batch' % d,
- 2: 'cache_%d_length' % d
+ ret.update(
+ {
+ "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+ for d in range(cache_num)
}
- for d in range(cache_num)
- })
+ )
return ret
-
\ No newline at end of file
--
Gitblit v1.9.1