From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/sanm/decoder.py | 104 +++++++++++++++++++++++++++++++---------------------
 1 file changed, 62 insertions(+), 42 deletions(-)
diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py
index 3575282..1a4fb26 100644
--- a/funasr/models/sanm/decoder.py
+++ b/funasr/models/sanm/decoder.py
@@ -13,13 +13,17 @@
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+ MultiHeadedAttentionSANMDecoder,
+ MultiHeadedAttentionCrossAtt,
+)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
+
class DecoderLayerSANM(nn.Module):
"""Single decoder layer module.
@@ -151,10 +155,11 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ def forward_chunk(
+ self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+ ):
"""Compute decoded features.
Args:
@@ -202,7 +207,7 @@
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
"""
-
+
def __init__(
self,
vocab_size: int,
@@ -240,7 +245,7 @@
)
if attention_dim is None:
attention_dim = encoder_output_size
-
+
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
@@ -255,7 +260,7 @@
)
else:
raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
+
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
@@ -263,7 +268,7 @@
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
-
+
self.att_layer_num = att_layer_num
self.num_blocks = num_blocks
if sanm_shfit is None:
@@ -276,7 +281,10 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+ attention_heads,
+ attention_dim,
+ src_attention_dropout_rate,
+ encoder_output_size=encoder_output_size,
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -292,7 +300,10 @@
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ attention_dim,
+ self_attention_dropout_rate,
+ kernel_size,
+ sanm_shfit=sanm_shfit,
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -301,7 +312,7 @@
concat_after,
),
)
-
+
self.decoders3 = repeat(
1,
lambda lnum: DecoderLayerSANM(
@@ -321,8 +332,12 @@
attention_dim + encoder_output_size,
None,
None,
- PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
- adim=attention_dim),
+ PositionwiseFeedForwardDecoderSANM(
+ attention_dim + encoder_output_size,
+ linear_units,
+ dropout_rate,
+ adim=attention_dim,
+ ),
dropout_rate,
normalize_before,
concat_after,
@@ -334,7 +349,7 @@
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
-
+
def forward(
self,
hs_pad: torch.Tensor,
@@ -363,47 +378,54 @@
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+
memory = hs_pad
memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
+
x = self.embed(tgt)
-
+
if pre_acoustic_embeds is not None and self.concat_embeds:
x = torch.cat((x, pre_acoustic_embeds), dim=-1)
x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
-
+
olens = tgt_mask.sum(1)
return x, olens
-
- def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+
+ def score(
+ self,
+ ys,
+ state,
+ x,
+ x_mask=None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ ):
"""Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ ys_mask = myutils.sequence_mask(
+ torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+ )[:, :, None]
logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
- cache=state
+ ys.unsqueeze(0),
+ ys_mask,
+ x.unsqueeze(0),
+ memory_mask=x_mask,
+ pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=state,
)
return logp.squeeze(0), state
-
+
def forward_one_step(
self,
tgt: torch.Tensor,
@@ -426,15 +448,15 @@
y, cache: NN output value and cache per `self.decoders`.
y.shape` is (batch, maxlen_out, token)
"""
-
+
x = tgt[:, -1:]
tgt_mask = None
x = self.embed(x)
-
+
if pre_acoustic_embeds is not None and self.concat_embeds:
x = torch.cat((x, pre_acoustic_embeds), dim=-1)
x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
+
if cache is None:
cache_layer_num = len(self.decoders)
if self.decoders2 is not None:
@@ -449,7 +471,7 @@
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
-
+
if self.num_blocks - self.att_layer_num >= 1:
for i in range(self.num_blocks - self.att_layer_num):
j = i + self.att_layer_num
@@ -459,12 +481,12 @@
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
-
+
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
x, tgt_mask, memory, None, cache=None
)
-
+
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
@@ -472,7 +494,5 @@
if self.output_layer is not None:
y = self.output_layer(y)
y = torch.log_softmax(y, dim=-1)
-
+
return y, new_cache
-
-
\ No newline at end of file
--
Gitblit v1.9.1